diff --git a/lavis/__init__.py b/lavis/__init__.py
index ab17686..9ea82d3 100644
--- a/lavis/__init__.py
+++ b/lavis/__init__.py
@@ -24,7 +24,8 @@ default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
 registry.register_path("library_root", root_dir)
 repo_root = os.path.join(root_dir, "..")
 registry.register_path("repo_root", repo_root)
-cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
+# cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
+cache_root = default_cfg.env.cache_root
 registry.register_path("cache_root", cache_root)
 
 registry.register("MAX_INT", sys.maxsize)
diff --git a/lavis/common/utils.py b/lavis/common/utils.py
index 93b93c9..cf8482b 100644
--- a/lavis/common/utils.py
+++ b/lavis/common/utils.py
@@ -18,6 +18,7 @@ import urllib.error
 import urllib.request
 from typing import Optional
 from urllib.parse import urlparse
+from pytz import timezone
 
 import numpy as np
 import pandas as pd
@@ -37,7 +38,11 @@ from torchvision.datasets.utils import (
 def now():
     from datetime import datetime
 
-    return datetime.now().strftime("%Y%m%d%H%M")[:-1]
+    # return datetime.now().strftime("%Y%m%d%H%M")[:-1]
+    fmt = '%Y_%m_%d_%H_%M_%S'
+    # EST5EDT, Asia/Calcutta
+    job_id = str(datetime.now(timezone('PST8PDT')).strftime(fmt))
+    return job_id
 
 
 def is_url(url_or_filename):
diff --git a/lavis/configs/datasets/vg/defaults_caption_instruct.yaml b/lavis/configs/datasets/vg/defaults_caption_instruct.yaml
index 8015e94..7460ab1 100644
--- a/lavis/configs/datasets/vg/defaults_caption_instruct.yaml
+++ b/lavis/configs/datasets/vg/defaults_caption_instruct.yaml
@@ -31,4 +31,4 @@ datasets:
           url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/visual_genome/vg_caption.json
           storage: vg/annotations/vg_caption.json
       images:
-        storage: /export/share/datasets/vision/visual-genome/ #vg/images/
+        storage: /export/share/datasets/vision/visual-genome/ # vg/images/
\ No newline at end of file
diff --git a/lavis/configs/datasets/vg/defaults_ckd.yaml b/lavis/configs/datasets/vg/defaults_ckd.yaml
new file mode 100644
index 0000000..a57a523
--- /dev/null
+++ b/lavis/configs/datasets/vg/defaults_ckd.yaml
@@ -0,0 +1,15 @@
+
+
+datasets:
+  vg_ckd:
+    # data_dir: ${env.data_dir}/datasets
+    data_type: images # [images|videos|features]
+
+    build_info:
+      # Be careful not to append minus sign (-) before split to avoid itemizing
+      annotations:
+        train:
+          url: ''
+          storage: vg/annotations/vg_objects_hallucinated_desc.json
+      images:
+        storage: vg/images/
diff --git a/lavis/configs/default.yaml b/lavis/configs/default.yaml
index f58d32e..17ba784 100644
--- a/lavis/configs/default.yaml
+++ b/lavis/configs/default.yaml
@@ -1,10 +1,5 @@
- # Copyright (c) 2022, salesforce.com, inc.
- # All rights reserved.
- # SPDX-License-Identifier: BSD-3-Clause
- # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-
 env:
   # For default users
   # cache_root: "cache"
   # For internal use with persistent storage
-  cache_root: "/export/home/.cache/lavis"
+  cache_root: "cache/" # TODO: change it based on your cache/data source
diff --git a/lavis/configs/models/blip2/blip2_instruct_ckd_lora_vicuna7b.yaml b/lavis/configs/models/blip2/blip2_instruct_ckd_lora_vicuna7b.yaml
new file mode 100644
index 0000000..c4ae45f
--- /dev/null
+++ b/lavis/configs/models/blip2/blip2_instruct_ckd_lora_vicuna7b.yaml
@@ -0,0 +1,40 @@
+
+
+model:
+  arch: blip2_vicuna_instruct_ckd
+  load_finetuned: False
+  load_pretrained: True
+
+  pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth"
+  finetuned: ""
+
+  # vit encoder
+  image_size: 224
+  drop_path_rate: 0
+  use_grad_checkpoint: False
+  vit_precision: "fp16"
+  freeze_vit: True
+
+  # Q-Former
+  num_query_token: 32
+
+  # path to Vicuna checkpoint
+  llm_model: "lmsys/vicuna-7b-v1.1" # "./llm/vicuna-7b"
+
+  # generation configs
+  prompt: ""
+
+
+preprocess:
+    vis_processor:
+        train:
+          name: "blip2_image_train"
+          image_size: 224
+        eval:
+          name: "blip_image_eval"
+          image_size: 224
+    text_processor:
+        train:
+          name: "blip_caption_ckd"
+        eval:
+          name: "blip_caption_ckd"
diff --git a/lavis/configs/models/blip2/blip2_instruct_ckd_vicuna13b.yaml b/lavis/configs/models/blip2/blip2_instruct_ckd_vicuna13b.yaml
new file mode 100644
index 0000000..2c6d315
--- /dev/null
+++ b/lavis/configs/models/blip2/blip2_instruct_ckd_vicuna13b.yaml
@@ -0,0 +1,40 @@
+
+
+model:
+  arch: blip2_vicuna_instruct_ckd
+  load_finetuned: False
+  load_pretrained: True
+
+  pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna13b_trimmed.pth"
+  finetuned: ""
+
+  # vit encoder
+  image_size: 224
+  drop_path_rate: 0
+  use_grad_checkpoint: False
+  vit_precision: "fp16"
+  freeze_vit: True
+
+  # Q-Former
+  num_query_token: 32
+
+  # path to Vicuna checkpoint
+  llm_model: "lmsys/vicuna-13b-v1.1" # "./llm/vicuna-13b"
+
+  # generation configs
+  prompt: ""
+
+
+preprocess:
+    vis_processor:
+        train:
+          name: "blip2_image_train"
+          image_size: 224
+        eval:
+          name: "blip_image_eval"
+          image_size: 224
+    text_processor:
+        train:
+          name: "blip_caption_ckd"
+        eval:
+          name: "blip_caption_ckd"
diff --git a/lavis/configs/models/blip2/blip2_instruct_ckd_vicuna7b.yaml b/lavis/configs/models/blip2/blip2_instruct_ckd_vicuna7b.yaml
new file mode 100644
index 0000000..c4ae45f
--- /dev/null
+++ b/lavis/configs/models/blip2/blip2_instruct_ckd_vicuna7b.yaml
@@ -0,0 +1,40 @@
+
+
+model:
+  arch: blip2_vicuna_instruct_ckd
+  load_finetuned: False
+  load_pretrained: True
+
+  pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/InstructBLIP/instruct_blip_vicuna7b_trimmed.pth"
+  finetuned: ""
+
+  # vit encoder
+  image_size: 224
+  drop_path_rate: 0
+  use_grad_checkpoint: False
+  vit_precision: "fp16"
+  freeze_vit: True
+
+  # Q-Former
+  num_query_token: 32
+
+  # path to Vicuna checkpoint
+  llm_model: "lmsys/vicuna-7b-v1.1" # "./llm/vicuna-7b"
+
+  # generation configs
+  prompt: ""
+
+
+preprocess:
+    vis_processor:
+        train:
+          name: "blip2_image_train"
+          image_size: 224
+        eval:
+          name: "blip_image_eval"
+          image_size: 224
+    text_processor:
+        train:
+          name: "blip_caption_ckd"
+        eval:
+          name: "blip_caption_ckd"
diff --git a/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml b/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml
index 1036539..b25c941 100644
--- a/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml
+++ b/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml
@@ -22,7 +22,7 @@ model:
   num_query_token: 32
 
   # path to Vicuna checkpoint
-  llm_model: "./llm/vicuna-13b"
+  llm_model: "lmsys/vicuna-13b-v1.1" # "./llm/vicuna-13b"
 
   # generation configs
   prompt: ""
diff --git a/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml b/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml
index af67777..724ee21 100644
--- a/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml
+++ b/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml
@@ -22,7 +22,7 @@ model:
   num_query_token: 32
 
   # path to Vicuna checkpoint
-  llm_model: "./llm/vicuna-7b"
+  llm_model: "lmsys/vicuna-7b-v1.1" # "./llm/vicuna-7b"
 
   # generation configs
   prompt: ""
diff --git a/lavis/datasets/builders/__init__.py b/lavis/datasets/builders/__init__.py
index baabb36..e97a687 100644
--- a/lavis/datasets/builders/__init__.py
+++ b/lavis/datasets/builders/__init__.py
@@ -73,6 +73,7 @@ from lavis.datasets.builders.vqa_builder import (
     AOKVQAInstructBuilder,
     VGVQABuilder,
     VGVQAInstructBuilder,
+    VGCKDBuilder,
     GQABuilder,
     GQAInstructBuilder,
     IconQABuilder,
@@ -205,6 +206,7 @@ __all__ = [
     "CharadeCaptionInstructBuilder",
     "COCOVQAInstructBuilder",
     "VGVQAInstructBuilder",
+    "VGCKDBuilder",
     "GQAInstructBuilder",
     "IconQAInstructBuilder",
     "SNLIVisualEntailmentInstructBuilder",
diff --git a/lavis/datasets/builders/vqa_builder.py b/lavis/datasets/builders/vqa_builder.py
index f2684bf..750abc7 100644
--- a/lavis/datasets/builders/vqa_builder.py
+++ b/lavis/datasets/builders/vqa_builder.py
@@ -10,7 +10,7 @@ from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder
 from lavis.common.registry import registry
 from lavis.datasets.datasets.aok_vqa_datasets import AOKVQADataset, AOKVQAEvalDataset, AOKVQAInstructDataset
 from lavis.datasets.datasets.coco_vqa_datasets import COCOVQADataset, COCOVQAEvalDataset, COCOVQAInstructDataset
-from lavis.datasets.datasets.vg_vqa_datasets import VGVQADataset, VGVQAInstructDataset
+from lavis.datasets.datasets.vg_vqa_datasets import VGVQADataset, VGVQAInstructDataset, VGDatasetCKD
 from lavis.datasets.datasets.gqa_datasets import GQADataset, GQAEvalDataset, GQAInstructDataset
 from lavis.datasets.datasets.iconqa_datasets import IconQADataset, IconQAEvalDataset, IconQAInstructDataset
 from lavis.datasets.datasets.ocr_datasets import OCRVQADataset, OCRVQAInstructDataset
@@ -52,6 +52,12 @@ class VGVQAInstructBuilder(BaseDatasetBuilder):
         "default": "configs/datasets/vg/defaults_vqa_instruct.yaml"}
 
 
+@registry.register_builder("vg_ckd")
+class VGCKDBuilder(BaseDatasetBuilder):
+    train_dataset_cls = VGDatasetCKD
+    DATASET_CONFIG_DICT = {"default": "configs/datasets/vg/defaults_ckd.yaml"}
+
+
 @registry.register_builder("ok_vqa")
 class OKVQABuilder(COCOVQABuilder):
     DATASET_CONFIG_DICT = {
diff --git a/lavis/datasets/datasets/aok_vqa_datasets.py b/lavis/datasets/datasets/aok_vqa_datasets.py
index 4306458..2fdedfd 100644
--- a/lavis/datasets/datasets/aok_vqa_datasets.py
+++ b/lavis/datasets/datasets/aok_vqa_datasets.py
@@ -73,9 +73,30 @@ class AOKVQAInstructDataset(AOKVQADataset):
         return data
 
     def collater(self, samples):
-        data = super().collater(samples)
-        data['text_output'] = data['answer']
-        return data
+        image_list, question_list, answer_list, weight_list = [], [], [], []
+        full_answer_list = []
+        num_answers = []
+
+        for sample in samples:
+            image_list.append(sample["image"])
+            question_list.append(sample["text_input"])
+            full_answer_list.append(sample["text_output"])
+
+            weight_list.extend(sample["weights"])
+
+            answers = sample["answers"]
+
+            answer_list.extend(answers)
+            num_answers.append(len(answers))
+
+        return {
+            "image": torch.stack(image_list, dim=0),
+            "text_input": question_list,
+            "answer": answer_list,
+            "text_output": full_answer_list,
+            "weight": torch.Tensor(weight_list),
+            "n_answers": torch.LongTensor(num_answers),
+        }
 
 
 class AOKVQAEvalDataset(VQAEvalDataset, __DisplMixin):
diff --git a/lavis/datasets/datasets/vg_vqa_datasets.py b/lavis/datasets/datasets/vg_vqa_datasets.py
index 85c8ee7..815e6fa 100644
--- a/lavis/datasets/datasets/vg_vqa_datasets.py
+++ b/lavis/datasets/datasets/vg_vqa_datasets.py
@@ -11,6 +11,7 @@ import random
 from PIL import Image
 
 from lavis.datasets.datasets.vqa_datasets import VQADataset
+import torch
 
 
 class VGVQADataset(VQADataset):
@@ -49,3 +50,45 @@ class VGVQAInstructDataset(VGVQADataset):
         data = super().collater(samples)
         data['text_output'] = data['answer']
         return data
+
+
+class VGDatasetCKD(VQADataset):
+    def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
+        super().__init__(vis_processor, text_processor, vis_root, ann_paths)
+
+    def __getitem__(self, index):
+        ann = self.annotation[index]
+
+        image_path = os.path.join(self.vis_root, ann["image"])
+        image = Image.open(image_path).convert("RGB")
+
+        image = self.vis_processor(image)
+        # question = self.text_processor(ann["question"])
+        question = "Describe this image in detail."
+        question = self.text_processor(question)
+
+        pos_descrition = self.text_processor(' '.join(ann["pos_description"]))
+        neg_descrition = self.text_processor(' '.join(ann["neg_description"]))
+
+        return {
+            "image": image,
+            "text_input": question,
+            "pos_descrition": pos_descrition,
+            "neg_descrition": neg_descrition,
+        }
+
+    def collater(self, samples):
+        image_list, question_list, pos_descrition_list, neg_descrition_list = [], [], [], []
+
+        for sample in samples:
+            image_list.append(sample["image"])
+            question_list.append(sample["text_input"])
+            pos_descrition_list.append(sample["pos_descrition"])
+            neg_descrition_list.append(sample["neg_descrition"])
+
+        return {
+            "image": torch.stack(image_list, dim=0),
+            "text_input": question_list,
+            "pos_descrition": pos_descrition_list,
+            "neg_descrition": neg_descrition_list,
+        }
diff --git a/lavis/models/__init__.py b/lavis/models/__init__.py
index 26ac9b2..3bce7ed 100644
--- a/lavis/models/__init__.py
+++ b/lavis/models/__init__.py
@@ -54,6 +54,9 @@ from lavis.models.gpt_models.gpt_dialogue import GPTDialogue
 
 from lavis.processors.base_processor import BaseProcessor
 
+from lavis.models.blip2_models.blip2_vicuna_instruct_ckd import Blip2VicunaInstructCKD
+from lavis.models.blip2_models.blip2_vicuna_instruct_ckd_lora import Blip2VicunaInstructCKDLoRA
+
 
 __all__ = [
     "load_model",
@@ -82,7 +85,7 @@ __all__ = [
     "Blip2OPT",
     "Blip2T5",
     "Blip2T5Instruct",
-    "Blip2VicunaInstruct",
+    "Blip2VicunaInstruct", "Blip2VicunaInstructCKD", "Blip2VicunaInstructCKDLoRA",
     "Blip2VicunaXInstruct",
     "PNPVQA",
     "Img2PromptVQA",
diff --git a/lavis/models/base_model.py b/lavis/models/base_model.py
index 50e7c43..cda61b7 100644
--- a/lavis/models/base_model.py
+++ b/lavis/models/base_model.py
@@ -92,6 +92,7 @@ class BaseModel(nn.Module):
             assert (
                 finetune_path is not None
             ), "Found load_finetuned is True, but finetune_path is None."
+            logging.info(f"loading finetuned weights from: {finetune_path}")
             self.load_checkpoint(url_or_filename=finetune_path)
         else:
             load_pretrained = cfg.get("load_pretrained", True)
@@ -99,6 +100,8 @@ class BaseModel(nn.Module):
                 # load pre-trained weights
                 pretrain_path = cfg.get("pretrained", None)
                 assert "Found load_finetuned is False, but pretrain_path is None."
+                logging.info(
+                    f"loading pretrained weights from: {pretrain_path}")
                 self.load_from_pretrained(
                     url_or_filename=pretrain_path, **kwargs)
 
diff --git a/lavis/models/blip2_models/blip2.py b/lavis/models/blip2_models/blip2.py
index 98a5071..de349dc 100644
--- a/lavis/models/blip2_models/blip2.py
+++ b/lavis/models/blip2_models/blip2.py
@@ -74,10 +74,10 @@ class Blip2Base(BaseModel):
             visual_encoder = create_eva_vit_g(
                 img_size, drop_path_rate, use_grad_checkpoint, precision
             )
-#         elif model_name == "eva2_clip_L":
-#             visual_encoder = create_eva2_vit_L(
-#                 img_size, drop_path_rate, use_grad_checkpoint, precision
-#             )
+        # elif model_name == "eva2_clip_L":
+        #     visual_encoder = create_eva2_vit_L(
+        #         img_size, drop_path_rate, use_grad_checkpoint, precision
+        #     )
         elif model_name == "clip_L":
             visual_encoder = create_clip_vit_L(
                 img_size, use_grad_checkpoint, precision)
diff --git a/lavis/models/blip2_models/blip2_t5_instruct.py b/lavis/models/blip2_models/blip2_t5_instruct.py
index 4aec003..073a8d1 100644
--- a/lavis/models/blip2_models/blip2_t5_instruct.py
+++ b/lavis/models/blip2_models/blip2_t5_instruct.py
@@ -123,6 +123,8 @@ class Blip2T5Instruct(Blip2Base):
         # print(samples["text_output"])
         # print('-----------------')
 
+        print(samples)
+
         image = samples["image"]
         with self.maybe_autocast():
             image_embeds = self.ln_vision(self.visual_encoder(image))
@@ -205,6 +207,7 @@ class Blip2T5Instruct(Blip2Base):
                 attention_mask=encoder_atts,
                 decoder_attention_mask=output_tokens.attention_mask,
                 return_dict=True,
+                # FIXME: targets shape is problematic -> [9, 4]; 2, 43, 2048
                 labels=targets,
             )
             loss = outputs.loss
diff --git a/lavis/models/blip2_models/blip2_t5_instruct_ckd.py b/lavis/models/blip2_models/blip2_t5_instruct_ckd.py
new file mode 100644
index 0000000..b4c08d3
--- /dev/null
+++ b/lavis/models/blip2_models/blip2_t5_instruct_ckd.py
@@ -0,0 +1,963 @@
+import logging
+import string
+import random
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.cuda.amp import autocast as autocast
+from transformers import T5TokenizerFast
+
+from lavis.common.registry import registry
+from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
+from lavis.models.blip2_models.modeling_t5 import T5Config, T5ForConditionalGeneration
+from transformers.modeling_outputs import BaseModelOutput
+from torch.nn import CrossEntropyLoss
+
+
+
+
+@registry.register_model("blip2_t5_instruct_ckd")
+class Blip2T5InstructCKD(Blip2Base):
+    """
+    BLIP2 T5 model.
+    Supported model types:
+        - flant5xl
+        - flant5xxl
+    Usage:
+        >>> from lavis.models import load_model
+        >>> model = load_model("blip2_t5_kd", "flant5xl")
+    """
+
+    PRETRAINED_MODEL_CONFIG_DICT = {
+        "flant5xl": "configs/models/blip2/blip2_instruct_ckd_flant5xl.yaml",
+        "flant5xxl": "configs/models/blip2/blip2_instruct_ckd_flant5xxl.yaml",
+    }
+
+    def __init__(
+        self,
+        vit_model="eva_clip_g",
+        img_size=224,
+        drop_path_rate=0,
+        use_grad_checkpoint=False,
+        vit_precision="fp16",
+        freeze_vit=True,
+        num_query_token=32,
+        t5_model="google/flan-t5-xl",
+        prompt="",
+        max_txt_len=128,
+        max_output_txt_len=256,
+        apply_lemmatizer=False,
+        num_few_shot_examples=0,
+        few_shot_prob=0,
+        qformer_text_input=True,
+        kd_loss='ckd',
+        alpha=0.5,
+    ):
+        """
+        apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas.
+        """
+        super().__init__()
+
+        assert kd_loss in ['kd', 'ckd']
+        self.kd_loss = kd_loss
+        self.alpha=alpha
+        self.tokenizer = self.init_tokenizer(truncation_side="left")
+
+        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+        )
+        if freeze_vit:
+            for name, param in self.visual_encoder.named_parameters():
+                param.requires_grad = False
+            self.visual_encoder = self.visual_encoder.eval()
+            self.visual_encoder.train = disabled_train
+            logging.info("freeze vision encoder")
+        else:
+            logging.info("train vision encoder")
+
+        self.Qformer, self.query_tokens = self.init_Qformer(
+            num_query_token, self.visual_encoder.num_features
+        )
+
+        if not qformer_text_input:
+            self.Qformer.bert.embeddings.word_embeddings = None
+            self.Qformer.bert.embeddings.position_embeddings = None
+            for layer in self.Qformer.bert.encoder.layer:
+                layer.output = None
+                layer.intermediate = None
+        else:
+            self.Qformer.resize_token_embeddings(len(self.tokenizer))
+        self.Qformer.cls = None
+
+        self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='left')
+        self.t5_output_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='right')
+
+        t5_config = T5Config.from_pretrained(t5_model)
+        t5_config.dense_act_fn = "gelu"
+        self.t5_model = T5ForConditionalGeneration.from_pretrained(
+            t5_model, config=t5_config
+        )
+
+        for name, param in self.t5_model.named_parameters():
+            param.requires_grad = False
+            param.data = param.data.bfloat16()
+
+        self.t5_proj = nn.Linear(
+            self.Qformer.config.hidden_size, self.t5_model.config.hidden_size
+        )
+
+        self.max_txt_len = max_txt_len
+        self.max_output_txt_len = max_output_txt_len
+        self.prompt = prompt
+
+        self._apply_lemmatizer = apply_lemmatizer
+        self._lemmatizer = None
+
+        self.num_few_shot_examples = num_few_shot_examples
+        self.few_shot_prob = few_shot_prob
+
+        self.qformer_text_input = qformer_text_input
+
+        n_parameters_train = sum(p.numel() for p in self.parameters() if p.requires_grad)/ 1.e6
+        n_parameters_total = sum(p.numel() for p in self.parameters())/ 1.e6
+        logging.info(f"total trainable parameter {n_parameters_train} million - total parameter {n_parameters_total} million")
+
+    def concat_pos_neg(self,
+                    pos_ids, pos_atts,
+                    neg_ids, neg_atts):
+        # total_len = pos_ids.shape[1]+neg_ids.shape[1]
+        # input_part_targets_len = []
+        bs=pos_ids.size(0)
+        sign = []
+        llm_tokens = {"input_ids": [], "attention_mask": []}
+        for i in range(bs):
+            # this_input_ones = input_atts[i].sum()
+            # input_part_targets_len.append(this_input_ones)
+
+            pos_len = pos_atts[i].sum()
+            neg_len = neg_atts[i].sum()
+            llm_tokens['input_ids'].append(
+                torch.cat([
+                    pos_ids[i][:pos_len],
+                    neg_ids[i][:neg_len],
+                    # following are ignored parts
+                    pos_ids[i][pos_len:],
+                    neg_ids[i][neg_len:],
+                ])
+            )
+            llm_tokens['attention_mask'].append(
+                torch.cat([
+                    pos_atts[i][:pos_len],
+                    neg_atts[i][:neg_len],
+                    # following are ignored parts
+                    pos_atts[i][pos_len:],
+                    neg_atts[i][neg_len:],
+                ])
+            )
+            sign.append(torch.cat([
+                    torch.ones(pos_len)*1, # positive
+                    torch.ones(neg_len)*-1, # negative
+                    torch.ones((len(pos_atts[i])-pos_len)+
+                               (len(neg_atts[i])-neg_len))*-100, # pads ignored
+                ]))
+            
+        llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
+        llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask'])
+        sign = torch.stack(sign)
+        return llm_tokens, sign
+
+
+    
+    def forward(self, samples):
+        # print('-----------------')
+        # print(samples["text_input"])
+        # print(samples["text_output"])
+        # print('-----------------')
+        DEBUG = True if samples['epoch']==0 and samples['iters']==0 else False
+        use_negatives=False
+        if self.kd_loss.startswith('ckd'):
+            use_negatives=True
+
+        image = samples["image"]
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image))
+
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        bs = image.size(0)
+        text_input = samples['text_input']
+        pos_descrition = samples['pos_descrition']
+        if use_negatives:
+            neg_descrition = samples['neg_descrition']
+        else:
+            neg_descrition = ['None']
+        
+        if DEBUG:
+            print(f"EPOCH {samples['epoch']}",  'text_input:', text_input[0])
+            print(f"EPOCH {samples['epoch']}",  
+                  'text_output: Positive:', pos_descrition[0]+' Negative: '+neg_descrition[0])
+
+        # tokenize
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        if self.qformer_text_input:
+            text_Qformer = self.tokenizer(
+                text_input,
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors="pt",
+            ).to(image.device)
+            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+            Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1)
+
+            query_output = self.Qformer.bert(
+                text_Qformer.input_ids,
+                attention_mask=Qformer_atts,
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+        else:
+            query_output = self.Qformer.bert(
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+
+        inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
+        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+
+        # few-shot is ignored in our setup
+        fs_embeds, fs_atts = None, None
+        # if self.few_shot_prob > 0 and "few_shot_samples" in samples.keys():
+        #     fs_embeds, fs_atts = self.prepare_few_shot_embeds(samples['few_shot_samples'])
+
+        with self.maybe_autocast(dtype=torch.bfloat16):
+            input_tokens = self.t5_tokenizer(
+                text_input,
+                padding="longest",
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors="pt",
+            ).to(image.device)
+
+            pos_rat_tokens = self.t5_output_tokenizer(
+                pos_descrition,
+                padding="longest",
+                truncation=True,
+                max_length=self.max_output_txt_len,
+                return_tensors="pt",
+            ).to(image.device)
+            if use_negatives:
+                neg_rat_tokens = self.t5_output_tokenizer(
+                    neg_descrition,
+                    padding="longest",
+                    truncation=True,
+                    max_length=self.max_output_txt_len,
+                    return_tensors="pt",
+                ).to(image.device)
+
+            output_tokens = {"input_ids": [], "attention_mask": []}
+            if use_negatives:
+                output_tokens, sign = self.concat_pos_neg(
+                    pos_rat_tokens.input_ids,
+                    pos_rat_tokens.attention_mask,
+                    neg_rat_tokens.input_ids,
+                    neg_rat_tokens.attention_mask,
+                )
+                sign = sign.type(torch.long).to(image.device)
+            else:
+                output_tokens['input_ids'] = pos_rat_tokens.input_ids
+                output_tokens['attention_mask'] = pos_rat_tokens.attention_mask
+
+
+            encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
+
+            targets = output_tokens['input_ids'].masked_fill(
+                output_tokens['input_ids'] == self.t5_tokenizer.pad_token_id, -100
+            )
+
+            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
+            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
+
+            ####### no contrastive kd
+            if not use_negatives:
+                
+                outputs = self.t5_model(
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=encoder_atts,
+                    decoder_attention_mask=output_tokens['attention_mask'],
+                    return_dict=True,
+                    return_lm_logits_seq_out=True, 
+                    labels=targets,
+                ) # loss, lm_logits, seq_out
+
+                loss = outputs[0]
+                return {"loss": loss, 'ce_loss': loss}
+
+            ####### contrastive kd 
+            
+            if self.kd_loss == 'ckd': 
+
+                labels = targets
+
+                # do not calculate loss for the negative desc
+                targets = targets.masked_fill(
+                    sign == -1, -100
+                    )
+
+                outputs = self.t5_model(
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=encoder_atts,
+                    decoder_attention_mask=output_tokens['attention_mask'],
+                    return_dict=True,
+                    return_lm_logits_seq_out=True, 
+                    labels=targets,
+                ) # loss, lm_logits, seq_out
+
+                pos_loss, lm_logits, seq_out = outputs
+
+                # contrastive loss
+                pos_logits=[]
+                neg_targets=[]
+
+                for k in range(bs):
+                    # logit
+                    _logit_sign = sign[k].cpu()
+                    pos_len = (_logit_sign==1).sum().item()
+                    pos_start = _logit_sign.numpy().tolist().index(1) # first found location
+
+                    # target
+                    _label_sign = sign[k].cpu()
+                    neg_len = (_label_sign==-1).sum().item()
+                    neg_start = _label_sign.numpy().tolist().index(-1) # first found location
+                    
+                    # logit-target
+                    _len = min(pos_len, neg_len) # we are avoiding zero-padding and trimming
+                    _pos_logits = lm_logits[k, pos_start:pos_start+_len] 
+                    _neg_targets = labels[k, neg_start:neg_start+_len]
+                    # pad to make size same
+
+                    pos_logits.append(_pos_logits)
+                    neg_targets.append(_neg_targets)
+
+                pos_logits = torch.cat(pos_logits)
+                neg_targets = torch.cat(neg_targets)
+
+                neg_loss = F.nll_loss(torch.log(torch.clamp((1.0 - F.softmax(pos_logits)), min=1e-5)), neg_targets, reduction='mean')
+                loss = pos_loss * self.alpha + neg_loss * (1-self.alpha)
+
+                return {"loss": loss, 'pos_loss': pos_loss, 'neg_loss': neg_loss}
+
+
+    def pad_zero_to_match_shape(self, 
+                                tensor_a, # larger
+                                tensor_b, # smaller
+                                ):
+        
+        # tensor_a = torch.randn(32, 43)
+        # tensor_b = torch.randn(32, 33)
+        
+        assert tensor_a.shape[1]>tensor_b.shape[1], "first item should be the larger one"
+
+        tensor_b_padded = torch.zeros_like(tensor_a)
+        tensor_b_padded[:, :tensor_b.shape[1]] = tensor_b
+
+        return tensor_b_padded
+
+    # def prepare_few_shot_embeds(self, samples):
+    #     this_n_fs = random.choices(
+    #         list(range(self.num_few_shot_examples + 1)),
+    #         weights=[1 - self.few_shot_prob] + [self.few_shot_prob / self.num_few_shot_examples] * self.num_few_shot_examples
+    #     )[0]
+
+    #     if this_n_fs == 0:
+    #         return None, None
+
+    #     images = []
+    #     text_input = []
+    #     for sample in samples:
+    #         for n in range(this_n_fs):
+    #             images.append(sample['image'][n])
+    #             text_input.append(sample['text_input'][n])
+    #     images = torch.stack(images, dim=0)
+
+    #     image = images
+
+    #     with self.maybe_autocast():
+    #         image_embeds = self.ln_vision(self.visual_encoder(image))
+    #     image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
+    #         image.device
+    #     )
+
+    #     query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+    #     if self.qformer_text_input:
+    #         text_Qformer = self.tokenizer(
+    #             text_input,
+    #             padding='longest',
+    #             truncation=True,
+    #             max_length=self.max_txt_len,
+    #             return_tensors="pt",
+    #         ).to(image.device)
+    #         query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+    #         Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1)
+    #         query_output = self.Qformer.bert(
+    #             text_Qformer.input_ids,
+    #             attention_mask = Qformer_atts,
+    #             query_embeds=query_tokens,
+    #             encoder_hidden_states=image_embeds,
+    #             encoder_attention_mask=image_atts,
+    #             return_dict=True,
+    #         )
+    #     else:
+    #         query_output = self.Qformer.bert(
+    #             query_embeds=query_tokens,
+    #             encoder_hidden_states=image_embeds,
+    #             encoder_attention_mask=image_atts,
+    #             return_dict=True,
+    #         )
+
+    #     inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
+    #     atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+
+    #     with self.maybe_autocast(dtype=torch.bfloat16):
+    #         input_tokens = self.t5_tokenizer(
+    #             text_input,
+    #             padding="longest",
+    #             truncation=True,
+    #             max_length=self.max_txt_len,
+    #             return_tensors="pt",
+    #         ).to(image.device)
+
+    #         encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
+
+    #         inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
+    #         inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
+
+    #     if this_n_fs > 1:
+    #         encoder_atts = encoder_atts.reshape(encoder_atts.size(0) // this_n_fs, encoder_atts.size(1) * this_n_fs)
+    #         inputs_embeds = inputs_embeds.reshape(inputs_embeds.size(0) // this_n_fs, inputs_embeds.size(1) * this_n_fs, inputs_embeds.size(2))
+
+    #     return inputs_embeds, encoder_atts
+
+    @torch.no_grad()
+    def generate(
+        self,
+        samples,
+        use_nucleus_sampling=False,
+        num_beams=5,
+        max_length=256,
+        min_length=1,
+        top_p=0.9,
+        repetition_penalty=1.5,
+        length_penalty=1.0,
+        num_captions=1,
+        temperature=1,
+    ):
+        """Generate text for ``samples["image"]`` conditioned on a prompt.
+
+        The prompt comes from ``samples["prompt"]`` when present, otherwise
+        ``self.prompt``; a single string is broadcast to the whole batch.
+        Returns a list of decoded strings (length ``bs * num_captions``).
+        Note: ``max_length`` is forwarded to T5 as ``max_new_tokens``.
+        """
+        if "prompt" in samples.keys():
+            prompt = samples["prompt"]
+        else:
+            prompt = self.prompt
+
+        image = samples["image"]
+
+        bs = image.size(0)
+
+        if isinstance(prompt, str):
+            prompt = [prompt] * bs
+        else:
+            assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
+
+        # For TextCaps
+        if "ocr_tokens" in samples.keys() and "{}" in prompt[0]:
+            prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)]
+
+        query_tokens = self.query_tokens.expand(bs, -1, -1)
+        if self.qformer_text_input:
+            # remove ocr tokens in q_former (for eval textvqa)
+            # qformer_prompt = prompt
+            # qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt]
+
+            # Tokenize the prompt so the Q-Former can condition its query
+            # tokens on the instruction text (InstructBLIP behaviour).
+            text_Qformer = self.tokenizer(
+                prompt,
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors="pt",
+            ).to(image.device)
+            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+            Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1)
+
+        # For video data
+        if image.dim() == 5:
+            # Encode each frame independently and concatenate the projected
+            # query outputs along the sequence dimension.
+            inputs_t5, atts_t5 = [], []
+            for j in range(image.size(2)):
+                this_frame = image[:,:,j,:,:]
+                with self.maybe_autocast():
+                    frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
+                frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+                if self.qformer_text_input:
+                    frame_query_output = self.Qformer.bert(
+                        text_Qformer.input_ids,
+                        attention_mask = Qformer_atts,
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
+                else:
+                    frame_query_output = self.Qformer.bert(
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
+
+                frame_inputs_t5 = self.t5_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
+                frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+                inputs_t5.append(frame_inputs_t5)
+                atts_t5.append(frame_atts_t5)
+            inputs_t5 = torch.cat(inputs_t5, dim=1)
+            atts_t5 = torch.cat(atts_t5, dim=1)
+        else:
+            with self.maybe_autocast():
+                image_embeds = self.ln_vision(self.visual_encoder(image))
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+            if self.qformer_text_input:
+                query_output = self.Qformer.bert(
+                    text_Qformer.input_ids,
+                    attention_mask=Qformer_atts,
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+            else:
+                query_output = self.Qformer.bert(
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+
+            # Keep only the query-token positions; project into T5's space.
+            inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
+            atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+
+        input_tokens = self.t5_tokenizer(
+            prompt,
+            padding="longest",
+            return_tensors="pt"
+        ).to(image.device)
+
+        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
+
+        with self.maybe_autocast(dtype=torch.bfloat16):
+            # Visual query embeddings are prepended to the prompt embeddings.
+            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
+            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
+
+            outputs = self.t5_model.generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=encoder_atts,
+                do_sample=use_nucleus_sampling,
+                top_p=top_p,
+                temperature=temperature,
+                num_beams=num_beams,
+                max_new_tokens=max_length,
+                min_length=min_length,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
+                num_return_sequences=num_captions,
+            )
+            output_text = self.t5_tokenizer.batch_decode(
+                outputs, skip_special_tokens=True
+            )
+
+        return output_text
+
+
+    # this is in use in eval.
+    def predict_answers(
+        self,
+        samples,
+        num_beams=5,
+        inference_method="generate",
+        max_len=10,
+        min_len=1,
+        num_ans_candidates=128,
+        answer_list=None,
+        prompt="",
+        length_penalty=-1,
+        **kwargs
+    ):
+        """Answer questions in ``samples["text_input"]`` via ``self.generate``.
+
+        ``prompt`` may contain one ``{}`` slot (the question) or two (OCR
+        tokens/choices + question). ``inference_method``, ``num_ans_candidates``
+        and ``answer_list`` are accepted for interface compatibility but unused
+        here. Lemmatizes the output when configured.
+        """
+        if isinstance(samples["text_input"], str):
+            samples["text_input"] = [samples["text_input"]]
+
+        if prompt:
+            if prompt.count("{}") == 2:
+                if 'ocr_tokens' in samples:
+                    # Slot 1: first 30 OCR tokens; slot 2: the question.
+                    text_input = [
+                        prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["text_input"][i])
+                    for i in range(len(samples["text_input"]))]
+                elif 'choices' in samples:
+                    # Multiple choice: render options as "(a) .. (b) ..".
+                    text_input = []
+                    for i in range(len(samples["text_input"])):
+                        this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])]
+                        this_choices = " ".join(this_choices)
+                        text_input.append(prompt.format(samples["text_input"][i], this_choices))
+            else:
+                text_input = [prompt.format(question) for question in samples["text_input"]]
+        else:
+            text_input = samples["text_input"]
+
+        samples["prompt"] = text_input
+        
+        # print(text_input[0])
+        output_text = self.generate(
+            samples,
+            num_beams=num_beams,
+            max_length=max_len,
+            min_length=min_len,
+            length_penalty=length_penalty
+        )
+        # print(output_text[0])
+        if self._apply_lemmatizer or ("apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]):
+            output_text = self._lemmatize(output_text)
+        
+
+        return output_text
+
+    def predict_class(
+        self,
+        samples,
+        candidates,
+        n_segments=1,
+    ):
+        """Rank candidate classes per sample; thin wrapper over ``_predict_class``.
+
+        When ``candidates`` is a list of lists (per-sample candidate sets),
+        each sample is scored independently with its own candidates.
+        """
+        # If candidates is a list of lists, each sample has its candidates, then we need to iterate one by one
+        if type(candidates[0]) == list:
+            results = []
+
+            for i in range(samples["image"].size(0)):
+                this_sample = {
+                    "image": samples["image"][i].unsqueeze(0),
+                    "prompt": samples["prompt"],
+                }
+
+                if "text_input" in samples.keys():
+                    this_sample["text_input"] = [samples["text_input"][i]]
+
+                if 'context' in samples.keys():
+                    this_sample['context'] = [samples["context"][i]]
+
+                if 'history' in samples.keys():
+                    this_sample['history'] = [samples["history"][i]]
+
+                if 'caption' in samples.keys():
+                    this_sample['caption'] = [samples["caption"][i]]
+
+                this_result = self._predict_class(this_sample, candidates[i], n_segments)
+                results.append(this_result)
+
+            try:
+                results = torch.cat(results, dim=0)
+            # NOTE(review): bare except -- presumably guards against ragged
+            # per-sample candidate counts making cat fail; consider narrowing
+            # to RuntimeError so real errors are not swallowed.
+            except:
+                results = [res.tolist()[0] for res in results]
+
+            return results
+
+        return self._predict_class(samples, candidates, n_segments)
+
+    def _predict_class(
+        self,
+        samples,
+        candidates,
+        n_segments=1,
+    ):
+        """
+        Args:
+            samples (dict): A dictionary containing the following keys:
+                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
+                - prompt: the instruction
+            candidates:
+                (list): A list of candidate class names;
+            n_segments:
+                (int): Split the candidates into n_segments and predict one by one. This is useful when the number of candidates is too large.
+        Returns:
+            output_class: predicted class index
+        """
+
+        image = samples["image"]
+        prompt = samples["prompt"]
+
+        bs = image.size(0)
+
+        if isinstance(prompt, str):
+            prompt = [prompt] * bs
+        else:
+            assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
+
+        if "text_input" in samples.keys():
+            # Fill the prompt template with the question (or question parts).
+            if type(samples["text_input"][0]) == list:
+                prompt = [prompt[i].format(*samples["text_input"][i]) for i in range(len(prompt))]
+            else:
+                prompt = [prompt[i].format(samples["text_input"][i]) for i in range(len(prompt))]
+
+        # scienceqa
+        if 'context' in samples.keys() and samples['context'] != '':
+            prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
+
+        # visual dialog
+        if 'history' in samples.keys() and samples['history'][0] != '':
+            prompt = [f'dialog history: {samples["history"][i]}\n{prompt[i]}' for i in range(len(prompt))]
+
+        if 'caption' in samples.keys() and samples['caption'][0] != '':
+            prompt = [f'This image has the caption "{samples["caption"][i]}". {prompt[i]}' for i in range(len(prompt))]
+
+        query_tokens = self.query_tokens.expand(bs, -1, -1)
+        if self.qformer_text_input:
+            text_Qformer = self.tokenizer(
+                prompt,
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors="pt"
+            ).to(image.device)
+            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+            Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask], dim=1)
+
+        # Video input: encode per frame and concatenate along the sequence dim.
+        if image.dim() == 5:
+            inputs_t5, atts_t5 = [], []
+            for j in range(image.size(2)):
+                this_frame = image[:,:,j,:,:]
+                with self.maybe_autocast():
+                    frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
+                    frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+                if self.qformer_text_input:
+                    frame_query_output = self.Qformer.bert(
+                        text_Qformer.input_ids,
+                        attention_mask=Qformer_atts,
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
+                else:
+                    frame_query_output = self.Qformer.bert(
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
+
+                frame_inputs_t5 = self.t5_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
+                frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+                inputs_t5.append(frame_inputs_t5)
+                atts_t5.append(frame_atts_t5)
+            inputs_t5 = torch.cat(inputs_t5, dim=1)
+            atts_t5 = torch.cat(atts_t5, dim=1)
+        else:
+            with self.maybe_autocast():
+                image_embeds = self.ln_vision(self.visual_encoder(image))
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+            if self.qformer_text_input:
+                query_output = self.Qformer.bert(
+                    text_Qformer.input_ids,
+                    attention_mask=Qformer_atts,
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+            else:
+                query_output = self.Qformer.bert(
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+
+            inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
+            atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
+
+        input_tokens = self.t5_tokenizer(
+            prompt, padding="longest", return_tensors="pt"
+        ).to(image.device)
+        output_tokens = self.t5_tokenizer(
+            candidates, padding="longest", return_tensors="pt"
+        ).to(image.device)
+
+        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
+
+        n_cands = len(candidates)
+
+        with self.maybe_autocast(dtype=torch.bfloat16):
+            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
+            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
+
+            # Run the encoder once; the encoder states are reused for every
+            # candidate segment below (cheaper than re-encoding per candidate).
+            encoder_outputs = self.t5_model.encoder(
+                inputs_embeds=inputs_embeds,
+                attention_mask=encoder_atts,
+            )
+
+            all_losses = []
+            for n in range(n_segments):
+                seg_len = n_cands // n_segments
+                # The last segment absorbs the remainder when n_cands does not
+                # divide evenly.
+                if n == (n_segments - 1):
+                    seg_len = n_cands - seg_len * (n_segments - 1)
+
+                # this_encoder_outputs = copy.deepcopy(encoder_outputs)
+                this_encoder_outputs = BaseModelOutput(
+                    last_hidden_state=encoder_outputs[0].clone(),
+                )
+
+                # Repeat each sample's encoder states once per candidate in
+                # this segment so all (sample, candidate) pairs score in batch.
+                this_encoder_outputs['last_hidden_state'] = this_encoder_outputs[0].repeat_interleave(seg_len, dim=0)
+                this_encoder_atts = encoder_atts.repeat_interleave(seg_len, dim=0)
+
+                start_i = n * (n_cands // n_segments)
+                end_i = start_i + seg_len
+                this_output_tokens_ids = output_tokens.input_ids[start_i:end_i].repeat(bs, 1)
+                this_output_tokens_atts = output_tokens.attention_mask[start_i:end_i].repeat(bs, 1)
+
+                # Pad positions -> -100 so they are ignored by the LM loss.
+                this_targets = this_output_tokens_ids.masked_fill(this_output_tokens_ids == self.t5_tokenizer.pad_token_id, -100)
+
+                # NOTE(review): `reduction` is not a standard HF T5 kwarg --
+                # presumably relies on LAVIS's modified T5 returning per-item
+                # losses; confirm against modeling_t5 in this repo.
+                outputs = self.t5_model(
+                    encoder_outputs=this_encoder_outputs,
+                    attention_mask=this_encoder_atts,
+                    decoder_attention_mask=this_output_tokens_atts,
+                    return_dict=True,
+                    labels=this_targets,
+                    reduction="none",
+                )
+                loss = outputs.loss
+
+                loss = loss.reshape(bs, seg_len)
+                # output_class_ranks = torch.argsort(loss, dim=-1)
+                all_losses.append(loss)
+
+            all_losses = torch.cat(all_losses, dim=-1)
+            # Lower LM loss = better candidate; ranks are candidate indices
+            # sorted best-first per sample.
+            output_class_ranks = torch.argsort(all_losses, dim=-1)
+
+            # encoder_outputs['last_hidden_state'] = encoder_outputs[0].repeat_interleave(n_cands, dim=0)
+            # encoder_atts = encoder_atts.repeat_interleave(n_cands, dim=0)
+            # output_tokens.input_ids = output_tokens.input_ids.repeat(bs, 1)
+            # output_tokens.attention_mask = output_tokens.attention_mask.repeat(bs, 1)
+
+            # # compute the LM loss for each candidate (sum logprob across all tokens) and select the highest
+            # targets = output_tokens.input_ids.masked_fill(output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100)
+
+            # outputs = self.t5_model(
+            #     encoder_outputs=encoder_outputs,
+            #     attention_mask=encoder_atts,
+            #     decoder_attention_mask=output_tokens.attention_mask,
+            #     return_dict=True,
+            #     labels=targets,
+            #     reduction="none",
+            # )
+            # loss = outputs.loss
+
+            # loss = loss.reshape(bs, n_cands)
+            # output_class_ranks = torch.argsort(loss, dim=-1) # (bs, num_candidates)
+
+        return output_class_ranks
+
+    def _lemmatize(self, answers):
+        """Lemmatize nouns and verbs in each answer string (spaCy).
+
+        Other POS tokens are kept verbatim; tokens are re-joined with spaces,
+        so original spacing/punctuation attachment is not preserved.
+        """
+        def apply(answer):
+            doc = self.lemmatizer(answer)
+
+            words = []
+            for token in doc:
+                if token.pos_ in ["NOUN", "VERB"]:
+                    words.append(token.lemma_)
+                else:
+                    words.append(token.text)
+            answer = " ".join(words)
+
+            return answer
+
+        return [apply(answer) for answer in answers]
+
+    @property
+    def lemmatizer(self):
+        """Lazily load and cache the spaCy ``en_core_web_sm`` pipeline.
+
+        Exits the process with an install hint if spaCy is unavailable.
+        """
+        if self._lemmatizer is None:
+            try:
+                import spacy
+
+                self._lemmatizer = spacy.load("en_core_web_sm")
+            # NOTE(review): spacy.load raises OSError (not ImportError) when
+            # the model package is missing -- confirm intended coverage.
+            except ImportError:
+                logging.error(
+                    """
+                    Please install spacy and en_core_web_sm model to apply lemmatization.
+                    python -m spacy download en_core_web_sm
+                    OR
+                    import spacy.cli
+                    spacy.cli.download("en_core_web_sm")
+                    """
+                )
+                exit(1)
+
+        return self._lemmatizer
+
+    @classmethod
+    def from_config(cls, cfg):
+        """Build the model from an OmegaConf/dict config and load its checkpoint.
+
+        Required keys: ``image_size``, ``num_query_token``, ``t5_model``.
+        All other keys fall back to the defaults shown below.
+        """
+        vit_model = cfg.get("vit_model", "eva_clip_g")
+        img_size = cfg.get("image_size")
+        num_query_token = cfg.get("num_query_token")
+        t5_model = cfg.get("t5_model")
+
+        drop_path_rate = cfg.get("drop_path_rate", 0)
+        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
+        vit_precision = cfg.get("vit_precision", "fp16")
+        freeze_vit = cfg.get("freeze_vit", True)
+
+        prompt = cfg.get("prompt", "")
+        max_txt_len = cfg.get("max_txt_len", 128)
+        max_output_txt_len = cfg.get("max_output_txt_len", 256)
+
+        apply_lemmatizer = cfg.get("apply_lemmatizer", False)
+
+        num_few_shot_examples = cfg.get("num_few_shot_examples", 0)
+        few_shot_prob = cfg.get("few_shot_prob", 0.0)
+
+        qformer_text_input = cfg.get("qformer_text_input", True)
+
+        # kd: knowledge-distillation options. alpha=1 weights the positive
+        # loss fully (no contrastive term).
+        alpha = cfg.get("alpha", 1) # no ckd
+        kd_loss = cfg.get("kd_loss", 'kd') # no ckd
+
+        model = cls(
+            vit_model=vit_model,
+            img_size=img_size,
+            drop_path_rate=drop_path_rate,
+            use_grad_checkpoint=use_grad_checkpoint,
+            vit_precision=vit_precision,
+            freeze_vit=freeze_vit,
+            num_query_token=num_query_token,
+            t5_model=t5_model,
+            prompt=prompt,
+            max_txt_len=max_txt_len,
+            max_output_txt_len=max_output_txt_len,
+            apply_lemmatizer=apply_lemmatizer,
+            num_few_shot_examples=num_few_shot_examples,
+            few_shot_prob=few_shot_prob,
+            qformer_text_input=qformer_text_input,
+            kd_loss=kd_loss,
+            alpha=alpha,
+        )
+
+        # if qformer_text_input:
+        #     # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal)
+        #     model.load_from_pretrained(
+        #         url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth"
+        #     )
+
+        model.load_checkpoint_from_config(cfg)
+
+        return model
diff --git a/lavis_ckd/models/blip2_models/blip2_vicuna_instruct_ckd.py b/lavis/models/blip2_models/blip2_vicuna_instruct_ckd.py
new file mode 100644
index 0000000..e65d03e
--- /dev/null
+++ b/lavis/models/blip2_models/blip2_vicuna_instruct_ckd.py
@@ -0,0 +1,1062 @@
+import logging
+import string
+from packaging import version
+
+import torch
+from torch.cuda.amp import autocast as autocast
+import torch.nn as nn
+import torch.nn.functional as F
+
+import transformers
+
+from lavis.common.registry import registry
+from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
+
+@registry.register_model("blip2_vicuna_instruct_ckd")
+class Blip2VicunaInstructCKD(Blip2Base):
+    """
+    BLIP2 Vicuna model.
+    Supported model types:
+        - vicuna7b
+        - vicuna13b
+    Usage:
+        >>> from lavis.models import load_model
+        >>> model = load_model("blip2_vicuna_instruct_ckd", "vicuna7b")
+    """
+
+    PRETRAINED_MODEL_CONFIG_DICT = {
+        "vicuna7b": "configs/models/blip2/blip2_instruct_ckd_vicuna7b.yaml",
+        "vicuna13b": "configs/models/blip2/blip2_instruct_ckd_vicuna13b.yaml",
+    }
+
+    def __init__(
+        self,
+        vit_model="eva_clip_g",
+        img_size=224,
+        drop_path_rate=0,
+        use_grad_checkpoint=False,
+        vit_precision="fp16",
+        freeze_vit=True,
+        num_query_token=32,
+        llm_model="",
+        prompt="",
+        max_txt_len=128,
+        max_output_txt_len=256,
+        apply_lemmatizer=False,
+        qformer_text_input=True,
+        kd_loss='kd', 
+        alpha=0,
+    ):
+        """Set up ViT encoder, Q-Former, frozen Vicuna LLM, and KD options.
+
+        kd_loss: one of 'kd', 'ckd', 'ckd-mask' (contrastive KD variants).
+        alpha: mixing weight between positive and negative loss terms.
+        """
+        super().__init__()
+        transformers_version = version.parse(transformers.__version__)
+        assert transformers_version >= version.parse("4.28"), "BLIP-2 Vicuna requires transformers>=4.28"        
+        from transformers import LlamaTokenizer
+        from lavis.models.blip2_models.modeling_llama import LlamaForCausalLM
+        
+        assert kd_loss in ['kd', 'ckd', 'ckd-mask']
+        self.kd_loss = kd_loss
+        self.alpha = alpha
+        self.tokenizer = self.init_tokenizer(truncation_side="left")
+
+        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+        )
+        if freeze_vit:
+            # Freeze all ViT weights and neutralize .train() so eval-mode
+            # statistics are kept during training.
+            for name, param in self.visual_encoder.named_parameters():
+                param.requires_grad = False
+            self.visual_encoder = self.visual_encoder.eval()
+            self.visual_encoder.train = disabled_train
+            logging.info("freeze vision encoder")
+        else:
+            logging.info("train vision encoder")
+
+        self.Qformer, self.query_tokens = self.init_Qformer(
+            num_query_token, self.visual_encoder.num_features
+        )
+
+        if not qformer_text_input:
+            # Strip the text pathway from the Q-Former when it only attends
+            # to visual features.
+            self.Qformer.bert.embeddings.word_embeddings = None
+            self.Qformer.bert.embeddings.position_embeddings = None
+            for layer in self.Qformer.bert.encoder.layer:
+                layer.output = None
+                layer.intermediate = None
+        else:
+            self.Qformer.resize_token_embeddings(len(self.tokenizer))
+        self.Qformer.cls = None
+
+        self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_model, use_fast=False, truncation_side="left")
+        self.llm_model = LlamaForCausalLM.from_pretrained(
+            llm_model, torch_dtype=torch.float16
+        )
+        # NOTE(review): bos/eos/unk are all remapped to '</s>' (InstructBLIP
+        # convention for Vicuna) -- confirm this matches the checkpoint's
+        # tokenizer expectations.
+        self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+        self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
+        self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
+        self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
+        # self.llm_tokenizer.pad_token = self.llm_tokenizer.unk_token
+
+        self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
+
+        # self.eos_token_id = self.llm_tokenizer(
+        #     self.llm_tokenizer.eos_token, add_special_tokens=False
+        # ).input_ids[0]
+
+        # The LLM is fully frozen; only Q-Former/projection train.
+        for name, param in self.llm_model.named_parameters():
+            param.requires_grad = False
+
+        self.llm_proj = nn.Linear(
+            self.Qformer.config.hidden_size, self.llm_model.config.hidden_size
+        )
+
+        self.max_txt_len = max_txt_len
+        self.max_output_txt_len = max_output_txt_len
+        self.prompt = prompt
+        prompt_tokens = self.llm_tokenizer(self.prompt, return_tensors="pt")
+        self.prompt_length = prompt_tokens.attention_mask.sum(1)
+
+        self._lemmatizer = None
+
+        self.qformer_text_input = qformer_text_input
+
+        # Log trainable vs. total parameter counts (in millions).
+        n_parameters_train = sum(p.numel() for p in self.parameters() if p.requires_grad)/ 1.e6
+        n_parameters_total = sum(p.numel() for p in self.parameters())/ 1.e6
+        logging.info(f"total trainable parameter {n_parameters_train} million - total parameter {n_parameters_total} million")
+
+
+    def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts):
+        """Splice each sample's output tokens directly after its real input tokens.
+
+        Per sample: [input w/o pad][output w/o leading BOS][input pad tail].
+        Returns the merged token dict and, per sample, the count of non-pad
+        input tokens (presumably used to mask the prompt out of the LM loss
+        at the call site -- confirm there).
+        """
+        input_part_targets_len = []
+        llm_tokens = {"input_ids": [], "attention_mask": []}
+        for i in range(input_ids.size(0)):
+            # Number of real (attended) input tokens for this sample.
+            this_input_ones = input_atts[i].sum()
+            input_part_targets_len.append(this_input_ones)
+            llm_tokens['input_ids'].append(
+                torch.cat([
+                    input_ids[i][:this_input_ones],
+                    output_ids[i][1:],
+                    input_ids[i][this_input_ones:]
+                ])
+            )
+            llm_tokens['attention_mask'].append(
+                torch.cat([
+                    input_atts[i][:this_input_ones],
+                    output_atts[i][1:],
+                    input_atts[i][this_input_ones:]
+                ])
+            )
+        llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
+        llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask'])
+        return llm_tokens, input_part_targets_len
+
+
+    def concat_input_pos_neg(self, input_ids, input_atts, 
+                                 pos_ids, pos_atts,
+                                 neg_ids, neg_atts):
+        """Merge prompt, positive-answer and negative-answer tokens per sample.
+
+        Layout per sample: [input][pos w/o BOS][neg incl. BOS][all pad tails].
+        Returns:
+            llm_tokens: merged input_ids/attention_mask tensors.
+            input_part_targets_len: per-sample count of real prompt tokens.
+            sign: per-position labels -- -100 for prompt/pad (ignored),
+                +1 for positive-answer tokens, -1 for negative-answer tokens.
+        """
+        # NOTE(review): total_len is computed but unused -- confirm/remove.
+        total_len = input_ids.shape[1]+pos_ids.shape[1]+neg_ids.shape[1]
+        input_part_targets_len = []
+        sign = []
+        llm_tokens = {"input_ids": [], "attention_mask": []}
+        for i in range(input_ids.size(0)):
+            this_input_ones = input_atts[i].sum()
+            input_part_targets_len.append(this_input_ones)
+
+
+            pos_len = pos_atts[i].sum()
+            neg_len = neg_atts[i].sum()
+            llm_tokens['input_ids'].append(
+                torch.cat([
+                    input_ids[i][:this_input_ones],
+                    pos_ids[i][1:pos_len], # removed bos
+                    neg_ids[i][:neg_len], # keep bos; pos and neg parts are not related # TODO: verify w/ team
+                    input_ids[i][this_input_ones:], # following are ignored parts
+                    pos_ids[i][pos_len:],
+                    neg_ids[i][neg_len:],
+                ])
+            )
+            llm_tokens['attention_mask'].append(
+                torch.cat([
+                    input_atts[i][:this_input_ones],
+                    pos_atts[i][1:pos_len], # removed bos
+                    neg_atts[i][:neg_len], # keep bos; pos and neg parts are not related # TODO: verify w/ team
+                    input_atts[i][this_input_ones:], # following are ignored parts
+                    pos_atts[i][pos_len:],
+                    neg_atts[i][neg_len:],
+                ])
+            )
+            sign.append(torch.cat([
+                    torch.ones(this_input_ones)*-100, # input ignored
+                    torch.ones(pos_len-1)*1, # positive # removed bos
+                    torch.ones(neg_len)*-1, # negative # keep bos # # TODO: verify w/ team
+                    torch.ones((len(input_atts[i])-this_input_ones)+
+                               (len(pos_atts[i])-pos_len)+
+                               (len(neg_atts[i])-neg_len))*-100, # pads ignored
+                ]))
+            
+        llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
+        llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask'])
+        sign = torch.stack(sign) # 
+        return llm_tokens, input_part_targets_len, sign
+
+
+    def forward(self, samples):
+        """Compute the training loss for one batch.
+
+        Encodes ``samples["image"]`` with the vision encoder + Q-Former,
+        projects the query outputs into the LLM embedding space, and computes a
+        causal-LM loss over the positive description. When ``self.kd_loss``
+        starts with 'ckd', the negative description is packed into the same
+        sequence and an extra contrastive (unlikelihood-style) term is added.
+
+        Keys read from ``samples`` (inferred from usage -- confirm with the
+        dataset/collator): "image", "text_input", "pos_descrition", "epoch",
+        "iters", and "neg_descrition" when contrastive KD is enabled.
+
+        Returns:
+            dict with "loss" plus branch-specific terms ("ce_loss" or
+            "pos_loss"/"neg_loss").
+        """
+        # print('-----------------')
+        # print(samples["text_input"])
+        # print(samples["text_output"])
+        # print('-----------------')
+        # Log example text only at the very first step of the first epoch.
+        DEBUG = True if samples['epoch']==0 and samples['iters']==0 else False
+        use_negatives=False
+        if self.kd_loss.startswith('ckd'):
+            use_negatives=True
+
+        image = samples["image"]
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image))
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        bs = image.size(0)
+        text_input = samples['text_input']
+        pos_descrition = samples['pos_descrition']
+        if use_negatives:
+            neg_descrition = samples['neg_descrition']
+        else:
+            # Placeholder so the DEBUG print below cannot fail on index [0].
+            neg_descrition = ['None']
+        
+        if DEBUG:
+            print(f"EPOCH {samples['epoch']}",  'text_input:', text_input[0])
+            print(f"EPOCH {samples['epoch']}",  
+                  'text_output: Positive:', pos_descrition[0]+' Negative: '+neg_descrition[0])
+
+        # tokenize
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        if self.qformer_text_input:
+            # Instruction-aware Q-Former: condition the queries on the text input.
+            text_Qformer = self.tokenizer(
+                text_input,
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors="pt",
+            ).to(image.device)
+            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+            Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],dim=1)
+
+            query_output = self.Qformer.bert(
+                text_Qformer.input_ids,
+                attention_mask=Qformer_atts,
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+        else:
+            query_output = self.Qformer.bert(
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+
+        # Project the query-token outputs into the LLM's input embedding space.
+        inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
+        atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+
+        # Left-truncate the instruction (keep its tail), right-pad the batch.
+        self.llm_tokenizer.padding_side = "right"
+        self.llm_tokenizer.truncation_side = 'left'
+        text_input_tokens = self.llm_tokenizer(
+            text_input,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=self.max_txt_len,
+        ).to(image.device)
+
+        # Targets are right-truncated; eos is appended so the model learns to stop.
+        self.llm_tokenizer.truncation_side = 'right'
+        text_pos_rat_tokens = self.llm_tokenizer(
+            [t + self.llm_tokenizer.eos_token for t in pos_descrition], # TODO: recheck eos_token; adding eos as pos part is not related to neg
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=self.max_output_txt_len,
+        ).to(image.device)
+
+        if use_negatives:
+            text_neg_rat_tokens = self.llm_tokenizer(
+                [t + self.llm_tokenizer.eos_token for t in neg_descrition], # TODO: recheck eos_token
+                return_tensors="pt",
+                padding="longest",
+                truncation=True,
+                max_length=self.max_output_txt_len,
+            ).to(image.device)
+
+        # merge tokens
+        if use_negatives:
+            # Pack [input | pos | neg | pads] per sample; `sign` marks
+            # positive tokens +1, negative tokens -1, everything else -100.
+            llm_tokens, input_part_targets_len, sign = self.concat_input_pos_neg(
+                text_input_tokens.input_ids,
+                text_input_tokens.attention_mask,
+                text_pos_rat_tokens.input_ids,
+                text_pos_rat_tokens.attention_mask,
+                text_neg_rat_tokens.input_ids,
+                text_neg_rat_tokens.attention_mask,
+            )
+        else:
+            llm_tokens, input_part_targets_len = self.concat_text_input_output(
+                text_input_tokens.input_ids,
+                text_input_tokens.attention_mask,
+                text_pos_rat_tokens.input_ids,
+                text_pos_rat_tokens.attention_mask,
+            )
+
+        # do not apply loss to the padding
+        targets = llm_tokens['input_ids'].masked_fill(
+            llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100
+        )
+
+        # do not apply loss to the text input (i.e., instruction)
+        for i, l in enumerate(input_part_targets_len):
+            targets[i][:l] = -100
+
+        # do not apply loss to the query tokens
+        empty_targets = (
+            torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
+        )
+        targets = torch.cat([empty_targets, targets], dim=1)
+
+        if use_negatives:
+            # Extend `sign` over the visual prefix so it stays aligned with `targets`.
+            sign = sign.type(torch.long).to(image.device)
+            sign = torch.cat([empty_targets, sign], dim=1)
+
+        inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids'])
+        inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+        attention_mask = torch.cat([atts_llm, llm_tokens['attention_mask']], dim=1)
+
+
+        ####### no contrastive kd
+        if not use_negatives: 
+            with self.maybe_autocast():
+                outputs = self.llm_model(
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
+                    return_dict=True,
+                    return_lm_logits_seq_out=True, 
+                    labels=targets,
+                ) # loss, logits, hidden_states
+
+            loss = outputs[0]
+            return {"loss": loss, 'ce_loss': loss}
+
+
+        ####### contrastive kd 
+
+        if self.kd_loss == 'ckd': 
+
+            labels = targets
+
+            # do not calculate loss for the negative desc
+            targets = targets.masked_fill(
+                sign == -1, -100
+                )
+            
+            with self.maybe_autocast():
+                outputs = self.llm_model(
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
+                    return_dict=True,
+                    return_lm_logits_seq_out=True, 
+                    labels=targets,
+                ) # loss, logits, hidden_states
+
+            pos_loss, logits, hidden_states = outputs
+            
+            # contrastive loss
+            pos_logits=[]
+            neg_targets=[]
+
+            # Standard causal shift: logits at position t predict token t+1.
+            shift_logits = logits[..., :-1, :].contiguous()
+            logit_sign = sign[..., :-1].contiguous() # matching with logits
+
+            shift_labels = labels[..., 1:].contiguous()
+            label_sign = sign[..., 1:].contiguous() # matching with labels
+            for k in range(bs):
+                # logit
+                _logit_sign = logit_sign[k].cpu()
+                pos_len = (_logit_sign==1).sum().item()
+                pos_start = _logit_sign.numpy().tolist().index(1) # first found location
+
+                # target
+                _label_sign = label_sign[k].cpu()
+                neg_len = (_label_sign==-1).sum().item()
+                neg_start = _label_sign.numpy().tolist().index(-1) # first found location
+                
+                # logit-target
+                _len = min(pos_len, neg_len)
+                _pos_logits = shift_logits[k, pos_start:pos_start+_len] # we are avoiding zero-padding and trimming
+                _neg_targets = shift_labels[k, neg_start:neg_start+_len]
+                # pad to make size same
+
+                pos_logits.append(_pos_logits)
+                neg_targets.append(_neg_targets)
+
+            pos_logits = torch.cat(pos_logits)
+            neg_targets = torch.cat(neg_targets)
+
+            # Unlikelihood term: penalize probability mass on negative tokens
+            # via log(1 - softmax); clamp guards log(0).
+            # NOTE(review): F.softmax is called without an explicit dim -- confirm
+            # the intended (last) dimension.
+            neg_loss = F.nll_loss(torch.log(torch.clamp((1.0 - F.softmax(pos_logits)), min=1e-5)), neg_targets, reduction='mean', ignore_index=-100)
+            loss = pos_loss * self.alpha + neg_loss * (1-self.alpha)
+
+            return {"loss": loss, 'pos_loss': pos_loss, 'neg_loss': neg_loss}
+
+
+        # TODO: add mask self.llm_model.get_input_embeddings()(llm_tokens['input_ids'])
+        elif self.kd_loss == 'ckd-mask':
+
+            labels = targets
+
+            # do not calculate loss for the negative desc
+            targets = targets.masked_fill(
+                sign == -1, -100
+                )
+            
+            with self.maybe_autocast():
+                outputs = self.llm_model(
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
+                    return_dict=True,
+                    return_lm_logits_seq_out=True, 
+                    labels=targets,
+                ) # loss, logits, hidden_states
+
+            pos_loss, logits, hidden_states = outputs
+            
+            # contrastive loss
+            pos_logits=[]
+            neg_targets=[]
+
+            shift_logits = logits[..., :-1, :].contiguous()
+            logit_sign = sign[..., :-1].contiguous() # matching with logits
+            shift_labels = labels[..., 1:].contiguous()
+            label_sign = sign[..., 1:].contiguous() # matching with labels
+            max_pos_len=0
+            max_neg_len=0
+            for k in range(bs):
+                # logit
+                _logit_sign = logit_sign[k].cpu()
+                pos_len = (_logit_sign==1).sum().item()
+                pos_start = _logit_sign.numpy().tolist().index(1) # first found location
+
+                # target
+                _label_sign = label_sign[k].cpu()
+                neg_len = (_label_sign==-1).sum().item()
+                neg_start = _label_sign.numpy().tolist().index(-1) # first found location
+
+                _pos_logits = shift_logits[k, pos_start:pos_start+pos_len] 
+                _neg_targets = shift_labels[k, neg_start:neg_start+neg_len]
+                
+                pos_logits.append(_pos_logits)
+                neg_targets.append(_neg_targets)
+
+                if max_pos_len<pos_len:
+                    max_pos_len=pos_len
+                if max_neg_len<neg_len:
+                    max_neg_len=neg_len
+
+            # zero pad
+            pos_logits = [self.pad_tensors(pl, max_pos_len, 0) for pl in pos_logits]
+            neg_targets = [self.pad_tensors(nt, max_neg_len, -100) for nt in neg_targets]
+            
+            # stack
+            pos_logits = torch.stack(pos_logits)
+            neg_targets = torch.stack(neg_targets)
+
+            _len = min(max_pos_len, max_neg_len)
+            pos_logits = pos_logits[:, :_len, :].contiguous()
+            neg_targets = neg_targets[:, :_len].contiguous()
+            
+            # calculate mask to apply neg loss only on negative objects
+            # we removed bos in pos_embed during training adjusted _pos_embed
+            _pos_embed = self.get_frozen_embed(text_pos_rat_tokens)[1][:, 1:, :]
+            _neg_embed = self.get_frozen_embed(text_neg_rat_tokens)[1]
+            _pos_atts = text_pos_rat_tokens.attention_mask[:, 1:]
+            _neg_atts = text_neg_rat_tokens.attention_mask
+
+            pos_len=_pos_embed.shape[1]
+            neg_len=_neg_embed.shape[1]
+            _len = min(pos_len, neg_len)
+            # trim to make shape same
+            _pos_embed=_pos_embed[:, :_len, :]
+            _neg_embed=_neg_embed[:, :_len, :]
+            _pos_atts=_pos_atts[:, :_len]
+            _neg_atts=_neg_atts[:, :_len]
+
+            cross = F.cosine_similarity(_pos_embed, _neg_embed, dim=-1) # -1 to 1 -> (add +1) -> 0 to 2 -> (divide by 2) -> 0 to 1
+            cross = (1+cross)/2 
+            inv_similarity = 1-cross # use this to multiply the loss
+
+            pos_logits = pos_logits.view(-1, pos_logits.size(-1))
+            neg_targets = neg_targets.view(-1)
+            inv_similarity = inv_similarity.view(-1)
+
+            # NOTE(review): F.softmax without an explicit dim, and inv_similarity
+            # length comes from the embeddings while neg_loss length comes from
+            # the packed logits -- confirm the element-wise product always aligns.
+            neg_loss = F.nll_loss(torch.log(torch.clamp((1.0 - F.softmax(pos_logits)), min=1e-5)), neg_targets, reduction='none', ignore_index=-100)
+            neg_loss = neg_loss*inv_similarity
+            # Average over non-ignored (non -100) negative tokens only.
+            neg_loss = (neg_loss * (neg_targets != -100).float()).sum()
+            neg_tok_num = max((neg_targets.reshape(-1) != -100).float().sum(), 1.0)
+            neg_loss = neg_loss / neg_tok_num
+            neg_loss = neg_loss*self.alpha
+
+            # loss = pos_loss * self.alpha + neg_loss * (1-self.alpha)
+            loss = pos_loss + neg_loss
+
+            return {"loss": loss, 'pos_loss': pos_loss, 'neg_loss': neg_loss}
+        # NOTE(review): any other kd_loss value starting with 'ckd' (e.g.
+        # 'ckd-pos', which __init__ accepts) falls through both branches above
+        # and this method implicitly returns None -- confirm intended.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+    def pad_zero_to_match_shape(self, 
+                                tensor_a, # larger
+                                tensor_b, # smaller
+                                ):
+        """Zero-pad `tensor_b` along dim 1 so it matches `tensor_a`'s shape.
+
+        Returns a new tensor shaped like `tensor_a` whose first
+        `tensor_b.shape[1]` columns hold `tensor_b` and the rest are zeros.
+        NOTE(review): the assert is strict (>), so equal widths are rejected
+        even though the copy below would handle them -- confirm intended.
+        """
+        
+        # tensor_a = torch.randn(32, 43)
+        # tensor_b = torch.randn(32, 33)
+        
+        assert tensor_a.shape[1]>tensor_b.shape[1], "first item should be the larger one"
+
+        tensor_b_padded = torch.zeros_like(tensor_a)
+        tensor_b_padded[:, :tensor_b.shape[1]] = tensor_b
+
+        return tensor_b_padded
+    
+    def pad_tensors(self, tensor, max_len, pad_with=0):
+        """Pad a 1-D or 2-D tensor along dim 0 to length `max_len` with `pad_with`.
+
+        Assumes tensor.shape[0] <= max_len; original values occupy the leading
+        rows and the remainder is filled with `pad_with` (same dtype/device).
+        """
+        assert len(tensor.shape)==2 or len(tensor.shape)==1
+        if len(tensor.shape)==2:
+            padded_tensor = torch.ones((max_len, tensor.shape[1]), dtype=tensor.dtype, device=tensor.device)*pad_with
+            padded_tensor[:tensor.shape[0]] = tensor
+        elif len(tensor.shape)==1:
+            padded_tensor = torch.ones(max_len, dtype=tensor.dtype, device=tensor.device)*pad_with
+            padded_tensor[:tensor.shape[0]] = tensor
+
+        return padded_tensor
+
+
+    @torch.no_grad()
+    def get_frozen_embed(self, token):
+        """Run the LLM on tokenizer output `token` without tracking gradients.
+
+        "Frozen" here means no_grad only -- the LLM's parameters are not
+        altered or re-frozen by this call. With labels=None the model output is
+        presumably (logits, hidden_states), per the commented unpacking below;
+        callers in this file index [1] for the hidden states.
+        """
+        inputs_embeds = self.llm_model.get_input_embeddings()(token.input_ids)
+        attention_mask = token.attention_mask
+        with self.maybe_autocast():
+            outputs = self.llm_model(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                return_dict=True,
+                return_lm_logits_seq_out=True, 
+                labels=None,
+            ) 
+
+        # logits, hidden_states = outputs
+        return outputs
+
+        
+
+    @torch.no_grad()
+    def generate(
+        self,
+        samples,
+        use_nucleus_sampling=False,
+        num_beams=5,
+        max_length=256,
+        min_length=1,
+        top_p=0.9,
+        repetition_penalty=1.5,
+        length_penalty=1,
+        num_captions=1,
+        temperature=1,
+    ):
+        """Generate text for a batch of images (or video) with the LLM.
+
+        Uses samples["prompt"] if present, else self.prompt. Visual features
+        are encoded per image -- or per frame when image.dim() == 5 -- through
+        the Q-Former, projected, and prepended to the prompt embeddings before
+        calling the LLM's generate().
+
+        Returns:
+            list[str]: decoded, stripped outputs (num_captions per input).
+        """
+        self.llm_tokenizer.padding_side = "left"
+
+        if "prompt" in samples.keys():
+            prompt = samples["prompt"]
+        else:
+            prompt = self.prompt
+
+        image = samples["image"]
+
+        bs = image.size(0)
+
+        if isinstance(prompt, str):
+            prompt = [prompt] * bs
+        else:
+            assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
+
+        # For TextCaps
+        if "ocr_tokens" in samples.keys() and "{}" in prompt[0]:
+            prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)]
+
+        query_tokens = self.query_tokens.expand(bs, -1, -1)
+        if self.qformer_text_input:
+            # remove ocr tokens in q_former (for eval textvqa)
+            # qformer_prompt = prompt
+            # qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt]
+
+            text_Qformer = self.tokenizer(
+                prompt,
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors="pt",
+            ).to(image.device)
+            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+            Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
+
+        # For video data
+        if image.dim() == 5:
+            # Encode each frame independently; concatenate query outputs along
+            # the sequence dimension.
+            inputs_llm, atts_llm = [], []
+            for j in range(image.size(2)):
+                this_frame = image[:,:,j,:,:]
+                with self.maybe_autocast():
+                    frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
+                frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+                if self.qformer_text_input:
+                    frame_query_output = self.Qformer.bert(
+                        text_Qformer.input_ids,
+                        attention_mask=Qformer_atts,
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
+                else:
+                    frame_query_output = self.Qformer.bert(
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
+                frame_inputs_llm = self.llm_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
+                frame_atts_llm = torch.ones(frame_inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+                inputs_llm.append(frame_inputs_llm)
+                atts_llm.append(frame_atts_llm)
+            inputs_llm = torch.cat(inputs_llm, dim=1)
+            atts_llm = torch.cat(atts_llm, dim=1)
+        else:
+            with self.maybe_autocast():
+                image_embeds = self.ln_vision(self.visual_encoder(image))
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+            if self.qformer_text_input:
+                query_output = self.Qformer.bert(
+                    text_Qformer.input_ids,
+                    attention_mask=Qformer_atts,
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+            else:
+                query_output = self.Qformer.bert(
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+
+            inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
+            atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+
+        llm_tokens = self.llm_tokenizer(
+            prompt,
+            padding="longest",
+            return_tensors="pt"
+        ).to(image.device)
+
+        # print('EVAL INPUT: ', prompt)
+        with self.maybe_autocast():
+            inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids)
+            inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+            attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1)
+
+            outputs = self.llm_model.generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                do_sample=use_nucleus_sampling,
+                top_p=top_p,
+                temperature=temperature,
+                num_beams=num_beams,
+                max_length=max_length,
+                min_length=min_length,
+                # eos_token_id=self.eos_token_id,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
+                num_return_sequences=num_captions,
+            )
+
+        outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id)
+        output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        output_text = [text.strip() for text in output_text]
+
+        # print('EVAL OUTPUT: ', output_text)
+
+        return output_text
+
+    def predict_answers(
+        self,
+        samples,
+        num_beams=5,
+        inference_method="generate",
+        max_len=10,
+        min_len=1,
+        num_ans_candidates=128,
+        answer_list=None,
+        prompt="",
+        length_penalty=0,
+        **kwargs
+    ):
+        """Answer questions by formatting `prompt` with the inputs and generating.
+
+        A prompt with two "{}" slots is filled with OCR tokens (TextVQA-style)
+        or lettered multiple-choice options plus the question; a single slot
+        gets just the question. Output is optionally lemmatized when
+        samples["apply_lemmatizer"] is truthy.
+
+        NOTE(review): in the two-slot branch, if samples has neither
+        'ocr_tokens' nor 'choices', `text_input` is never bound and the
+        assignment below raises NameError -- confirm callers guarantee one.
+        """
+        if isinstance(samples["text_input"], str):
+            samples["text_input"] = [samples["text_input"]]
+
+        if prompt:
+            if prompt.count("{}") == 2:
+                if 'ocr_tokens' in samples:
+                    text_input = [
+                        prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["text_input"][i])
+                    for i in range(len(samples["text_input"]))]
+                elif 'choices' in samples:
+                    text_input = []
+                    for i in range(len(samples["text_input"])):
+                        this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])]
+                        this_choices = " ".join(this_choices)
+                        text_input.append(prompt.format(samples["text_input"][i], this_choices))
+            else:
+                text_input = [prompt.format(question) for question in samples["text_input"]]
+        else:
+            text_input = samples["text_input"]
+
+        samples["prompt"] = text_input
+
+        output_text = self.generate(
+            samples,
+            num_beams=num_beams,
+            max_length=max_len,
+            min_length=min_len,
+            length_penalty=length_penalty
+        )
+
+        if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
+            output_text = self._lemmatize(output_text)
+
+        return output_text
+
+    def predict_class(
+        self,
+        samples,
+        candidates,
+        n_segments=1,
+    ):
+        """Rank `candidates` for each sample by LM loss (lower loss = better).
+
+        When `candidates` is a list of lists (per-sample candidate sets), each
+        sample is scored individually via _predict_class; otherwise the shared
+        candidate list is scored in one batched call. Per-sample results are
+        concatenated when shapes allow, else returned as plain lists.
+        """
+        self.llm_tokenizer.padding_side = "left"
+
+        # If candidates is a list of lists, each sample has its candidates, then we need to iterate one by one
+        if type(candidates[0]) == list:
+            results = []
+
+            for i in range(samples["image"].size(0)):
+                # Build a single-sample view, carrying over optional fields.
+                this_sample = {
+                    "image": samples["image"][i].unsqueeze(0),
+                    "prompt": samples["prompt"],
+                }
+
+                if "text_input" in samples.keys():
+                    this_sample["text_input"] = [samples["text_input"][i]]
+
+                if 'context' in samples.keys():
+                    this_sample['context'] = [samples["context"][i]]
+
+                if 'history' in samples.keys():
+                    this_sample['history'] = [samples["history"][i]]
+
+                if 'caption' in samples.keys():
+                    this_sample['caption'] = [samples["caption"][i]]
+
+                this_result = self._predict_class(this_sample, candidates[i], n_segments)
+                results.append(this_result)
+
+            try:
+                results = torch.cat(results, dim=0)
+            except:
+                # Ragged per-sample candidate counts cannot be concatenated.
+                results = [res.tolist()[0] for res in results]
+
+            return results
+
+        return self._predict_class(samples, candidates, n_segments)
+
+    def _predict_class(
+        self,
+        samples,
+        candidates,
+        n_segments=1,
+    ):
+        """Score each candidate continuation by causal-LM loss and rank them.
+
+        Builds the per-sample prompt (optionally injecting context, dialog
+        history, and caption), encodes the image or video through the
+        Q-Former, then for each of `n_segments` slices of `candidates`
+        computes the per-sequence loss of prompt+candidate. `n_segments`
+        trades memory for speed by splitting the candidate set.
+
+        Returns:
+            Tensor [bs, n_candidates]: argsort of losses, ascending
+            (column 0 holds each sample's best candidate index).
+        """
+        image = samples["image"]
+        prompt = samples["prompt"]
+
+        bs = image.size(0)
+
+        if isinstance(prompt, str):
+            prompt = [prompt] * bs
+        else:
+            assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
+
+        if "text_input" in samples.keys():
+            if type(samples["text_input"][0]) == list:
+                prompt = [prompt[i].format(*samples["text_input"][i]) for i in range(len(prompt))]
+            else:
+                prompt = [prompt[i].format(samples["text_input"][i]) for i in range(len(prompt))]
+
+        # scienceqa
+        if 'context' in samples.keys() and samples['context'] != '':
+            prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
+
+        # visual dialog
+        if 'history' in samples.keys() and samples['history'][0] != '':
+            prompt = [f'dialog history: {samples["history"][i]}\n{prompt[i]}' for i in range(len(prompt))]
+
+        if 'caption' in samples.keys() and samples['caption'][0] != '':
+            prompt = [f'This image has the caption "{samples["caption"][i]}". {prompt[i]}' for i in range(len(prompt))]
+
+        query_tokens = self.query_tokens.expand(bs, -1, -1)
+        if self.qformer_text_input:
+            text_Qformer = self.tokenizer(
+                prompt,
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors="pt"
+            ).to(image.device)
+            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+            Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
+
+        # Video input: encode frame-by-frame and concatenate query outputs.
+        if image.dim() == 5:
+            inputs_llm, atts_llm = [], []
+            for j in range(image.size(2)):
+                this_frame = image[:,:,j,:,:]
+                with self.maybe_autocast():
+                    frame_embeds = self.ln_vision(self.visual_encoder(this_frame))
+                    frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+                if self.qformer_text_input:
+                    frame_query_output = self.Qformer.bert(
+                        text_Qformer.input_ids,
+                        attention_mask=Qformer_atts,
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
+                else:
+                    frame_query_output = self.Qformer.bert(
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
+
+                frame_inputs_llm = self.llm_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:])
+                frame_atts_llm = torch.ones(frame_inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+                inputs_llm.append(frame_inputs_llm)
+                atts_llm.append(frame_atts_llm)
+            inputs_llm = torch.cat(inputs_llm, dim=1)
+            atts_llm = torch.cat(atts_llm, dim=1)
+        else:
+            with self.maybe_autocast():
+                image_embeds = self.ln_vision(self.visual_encoder(image))
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+            if self.qformer_text_input:
+                query_output = self.Qformer.bert(
+                    text_Qformer.input_ids,
+                    attention_mask=Qformer_atts,
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+            else:
+                query_output = self.Qformer.bert(
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+
+            inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:])
+            atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+
+        self.llm_tokenizer.padding_side = "right"
+        self.llm_tokenizer.truncation_side = 'left'
+        text_input_tokens = self.llm_tokenizer(
+            prompt,
+            return_tensors="pt",
+            padding="longest",
+            # truncation=True,
+            # max_length=self.max_txt_len,
+        ).to(image.device)
+
+        # Visual prefix positions are excluded from the loss.
+        empty_targets = torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
+
+        # self.llm_tokenizer.padding_side = "right"
+        self.llm_tokenizer.truncation_side = 'right'
+        n_cands = len(candidates)
+        with self.maybe_autocast(dtype=torch.bfloat16):
+            all_losses = []
+            for n in range(n_segments):
+                seg_len = n_cands // n_segments
+                if n == (n_segments - 1):
+                    # Last segment absorbs the remainder when n_cands is not divisible.
+                    seg_len = n_cands - seg_len * (n_segments - 1)
+
+                start_i = n * (n_cands // n_segments)
+                end_i = start_i + seg_len
+
+                this_output_tokens = self.llm_tokenizer(
+                    candidates[start_i:end_i],
+                    return_tensors="pt",
+                    padding="longest",
+                    # truncation=True,
+                    # max_length=self.max_output_txt_len,
+                ).to(image.device)
+
+                # Pair every sample with every candidate in this segment.
+                this_input_tokens_ids = text_input_tokens.input_ids.repeat_interleave(seg_len, dim=0)
+                this_input_tokens_atts = text_input_tokens.attention_mask.repeat_interleave(seg_len, dim=0)
+
+                this_output_tokens_ids = this_output_tokens.input_ids.repeat(bs, 1)
+                this_output_tokens_atts = this_output_tokens.attention_mask.repeat(bs, 1)
+
+                this_llm_tokens, this_input_targets_len = self.concat_text_input_output(
+                    this_input_tokens_ids,
+                    this_input_tokens_atts,
+                    this_output_tokens_ids,
+                    this_output_tokens_atts
+                )
+
+                this_llm_input_ids = this_llm_tokens['input_ids']
+                this_llm_atts = this_llm_tokens['attention_mask']
+                # this_llm_input_ids = torch.cat([this_input_tokens_ids, this_output_tokens_ids], dim=1)
+                # this_llm_atts = torch.cat([this_input_tokens_atts, this_output_tokens_atts], dim=1)
+
+                inputs_embeds = self.llm_model.get_input_embeddings()(this_llm_input_ids)
+                inputs_embeds = torch.cat([inputs_llm.repeat_interleave(seg_len, dim=0), inputs_embeds], dim=1)
+                attention_mask = torch.cat([atts_llm.repeat_interleave(seg_len, dim=0), this_llm_atts], dim=1)
+
+                # Mask padding and the prompt part so loss covers only the candidate.
+                this_targets = this_llm_input_ids.masked_fill(this_llm_input_ids == self.llm_tokenizer.pad_token_id, -100)
+                # this_targets[:, :this_input_tokens_ids.size(1)] = -100
+                for i, l in enumerate(this_input_targets_len):
+                    this_targets[i][:l] = -100
+
+                this_targets = torch.cat([empty_targets.repeat_interleave(seg_len, dim=0), this_targets], dim=1)
+
+                outputs = self.llm_model(
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
+                    return_dict=True,
+                    labels=this_targets,
+                    reduction="none",
+                )
+
+                loss = outputs.loss
+
+                loss = loss.reshape(bs, seg_len)
+                # output_class_ranks = torch.argsort(loss, dim=-1)
+                all_losses.append(loss)
+
+            all_losses = torch.cat(all_losses, dim=-1)
+            output_class_ranks = torch.argsort(all_losses, dim=-1)
+
+        return output_class_ranks
+
+    def _lemmatize(self, answers):
+        """Lemmatize NOUN/VERB tokens in each answer via spaCy; keep other tokens."""
+        def apply(answer):
+            doc = self.lemmatizer(answer)
+
+            words = []
+            for token in doc:
+                if token.pos_ in ["NOUN", "VERB"]:
+                    words.append(token.lemma_)
+                else:
+                    words.append(token.text)
+            answer = " ".join(words)
+
+            return answer
+
+        return [apply(answer) for answer in answers]
+
+    @property
+    def lemmatizer(self):
+        """Lazily load and cache the spaCy `en_core_web_sm` pipeline.
+
+        Terminates the process via exit(1) if spaCy or the model is missing.
+        """
+        if self._lemmatizer is None:
+            try:
+                import spacy
+
+                self._lemmatizer = spacy.load("en_core_web_sm")
+            except ImportError:
+                logging.error(
+                    """
+                    Please install spacy and en_core_web_sm model to apply lemmatization.
+                    python -m spacy download en_core_web_sm
+                    OR
+                    import spacy.cli
+                    spacy.cli.download("en_core_web_sm")
+                    """
+                )
+                exit(1)
+
+        return self._lemmatizer
+
+    @classmethod
+    def from_config(cls, cfg):
+        """Build the model from a config node, then load its checkpoint.
+
+        Defaults: alpha=1 and kd_loss='kd', i.e. contrastive KD disabled.
+        """
+        vit_model = cfg.get("vit_model", "eva_clip_g")
+        img_size = cfg.get("image_size")
+        num_query_token = cfg.get("num_query_token")
+        llm_model = cfg.get("llm_model")
+
+        drop_path_rate = cfg.get("drop_path_rate", 0)
+        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
+        vit_precision = cfg.get("vit_precision", "fp16")
+        freeze_vit = cfg.get("freeze_vit", True)
+
+        prompt = cfg.get("prompt", "")
+        max_txt_len = cfg.get("max_txt_len", 128)
+        max_output_txt_len = cfg.get("max_output_txt_len", 256)
+
+        apply_lemmatizer = cfg.get("apply_lemmatizer", False)
+
+        qformer_text_input = cfg.get("qformer_text_input", True)
+
+        alpha = cfg.get("alpha", 1) # no ckd
+        kd_loss = cfg.get("kd_loss", 'kd') # no ckd
+
+        model = cls(
+            vit_model=vit_model,
+            img_size=img_size,
+            drop_path_rate=drop_path_rate,
+            use_grad_checkpoint=use_grad_checkpoint,
+            vit_precision=vit_precision,
+            freeze_vit=freeze_vit,
+            num_query_token=num_query_token,
+            llm_model=llm_model,
+            prompt=prompt,
+            max_txt_len=max_txt_len,
+            max_output_txt_len=max_output_txt_len,
+            apply_lemmatizer=apply_lemmatizer,
+            qformer_text_input=qformer_text_input,
+            alpha=alpha,
+            kd_loss=kd_loss,
+        )
+
+        # if qformer_text_input:
+        #     # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal)
+        #     model.load_from_pretrained(
+        #         url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth"
+        #     )
+
+        model.load_checkpoint_from_config(cfg)
+
+        return model
diff --git a/lavis_ckd/models/blip2_models/blip2_vicuna_instruct_ckd_lora.py b/lavis/models/blip2_models/blip2_vicuna_instruct_ckd_lora.py
new file mode 100644
index 0000000..45bba29
--- /dev/null
+++ b/lavis/models/blip2_models/blip2_vicuna_instruct_ckd_lora.py
@@ -0,0 +1,1024 @@
+import logging
+import string
+import os
+from packaging import version
+import torch
+from torch.cuda.amp import autocast as autocast
+import torch.nn as nn
+import torch.nn.functional as F
+import transformers
+from peft import LoraConfig, get_peft_model
+from lavis.common.registry import registry
+from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train
+from lavis.common.utils import is_url
+from lavis.common.dist_utils import download_cached_file
+@registry.register_model("blip2_vicuna_instruct_ckd_lora")
+class Blip2VicunaInstructCKDLoRA(Blip2Base):
+    """
+    BLIP2 Vicuna model with contrastive knowledge distillation (CKD) losses
+    and LoRA adapters applied to the otherwise frozen Vicuna LLM.
+    Supported model types:
+        - vicuna7b
+        - vicuna13b
+    Usage:
+        >>> from lavis.models import load_model
+        >>> model = load_model("blip2_vicuna_instruct_ckd_lora", "vicuna7b")
+    """
+
+    # Maps each supported model type to its default config file (path relative
+    # to the library root).
+    PRETRAINED_MODEL_CONFIG_DICT = {
+        "vicuna7b": "configs/models/blip2/blip2_instruct_ckd_lora_vicuna7b.yaml",
+        "vicuna13b": "configs/models/blip2/blip2_instruct_ckd_lora_vicuna13b.yaml",  # not created
+    }
+
+    def __init__(
+        self,
+        vit_model="eva_clip_g",
+        img_size=224,
+        drop_path_rate=0,
+        use_grad_checkpoint=False,
+        vit_precision="fp16",
+        freeze_vit=True,
+        num_query_token=32,
+        llm_model="",
+        prompt="",
+        max_txt_len=128,
+        max_output_txt_len=256,
+        apply_lemmatizer=False,
+        qformer_text_input=True,
+        kd_loss='kd',
+        alpha=0,
+        llm_lora_r=8,
+        llm_lora_apply="attn",
+    ):
+        """Build the frozen-ViT + Q-Former + LoRA-adapted Vicuna stack.
+
+        Args:
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint,
+                vit_precision, freeze_vit: vision-encoder settings forwarded
+                to ``init_vision_encoder``.
+            num_query_token: number of learnable Q-Former query tokens.
+            llm_model: HuggingFace path of the Vicuna checkpoint to load.
+            prompt: default text prompt used at generation time.
+            max_txt_len / max_output_txt_len: tokenizer truncation limits for
+                the instruction part and the answer part respectively.
+            apply_lemmatizer: NOTE(review) accepted but never stored on self
+                here; ``predict_answers`` reads the flag from ``samples``
+                instead — confirm this is intended.
+            qformer_text_input: if True, the instruction text is also fed to
+                the Q-Former alongside the image.
+            kd_loss: one of 'kd', 'ckd', 'ckd-pos'; selects the loss used in
+                ``forward``.
+            alpha: weight of the positive CE term in the contrastive KD loss.
+            llm_lora_r: LoRA rank; lora_alpha is fixed to 2 * r.
+            llm_lora_apply: which LLM linears get LoRA — 'attn', 'ffn', 'all'.
+        """
+        super().__init__()
+        transformers_version = version.parse(transformers.__version__)
+        assert transformers_version >= version.parse(
+            "4.28"), "BLIP-2 Vicuna requires transformers>=4.28"
+        from transformers import LlamaTokenizer
+        from lavis.models.blip2_models.modeling_llama import LlamaForCausalLM
+
+        assert kd_loss in ['kd', 'ckd', 'ckd-pos']
+        self.kd_loss = kd_loss
+        self.alpha = alpha
+        self.tokenizer = self.init_tokenizer(truncation_side="left")
+
+        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+        )
+        if freeze_vit:
+            # Freeze the ViT and pin it to eval mode; disabled_train keeps it
+            # from being flipped back to train mode by model.train().
+            for name, param in self.visual_encoder.named_parameters():
+                param.requires_grad = False
+            self.visual_encoder = self.visual_encoder.eval()
+            self.visual_encoder.train = disabled_train
+            logging.info("freeze vision encoder")
+
+        self.Qformer, self.query_tokens = self.init_Qformer(
+            num_query_token, self.visual_encoder.num_features
+        )
+
+        if not qformer_text_input:
+            # Strip the text pathway of the Q-Former when it only attends to
+            # image features.
+            self.Qformer.bert.embeddings.word_embeddings = None
+            self.Qformer.bert.embeddings.position_embeddings = None
+            for layer in self.Qformer.bert.encoder.layer:
+                layer.output = None
+                layer.intermediate = None
+        else:
+            self.Qformer.resize_token_embeddings(len(self.tokenizer))
+        self.Qformer.cls = None
+
+        self.llm_tokenizer = LlamaTokenizer.from_pretrained(
+            llm_model, use_fast=False, truncation_side="left")
+        self.llm_model = LlamaForCausalLM.from_pretrained(
+            llm_model, torch_dtype=torch.float16
+        )
+        # Register '[PAD]' as a new pad token and map bos/eos/unk all to '</s>'.
+        self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+        self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
+        self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
+        self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
+        # self.llm_tokenizer.pad_token = self.llm_tokenizer.unk_token
+
+        self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
+
+        # self.eos_token_id = self.llm_tokenizer(
+        #     self.llm_tokenizer.eos_token, add_special_tokens=False
+        # ).input_ids[0]
+
+        # Freeze the base LLM; only the LoRA adapters added below will train.
+        for name, param in self.llm_model.named_parameters():
+            param.requires_grad = False
+
+        # NOTE(review): helper is defined but never called below — dead code?
+        def _find_all_linear_names(model):
+            cls = torch.nn.Linear
+            lora_module_names = set()
+            for name, module in model.named_modules():
+                if isinstance(module, cls):
+                    names = name.split('.')
+                    lora_module_names.add(
+                        names[0] if len(names) == 1 else names[-1])
+
+            # if 'lm_head' in lora_module_names: # needed for 16-bit
+            #     lora_module_names.remove('lm_head')
+            return list(lora_module_names)
+
+        target_modules = []
+        if llm_lora_apply == "attn":
+            target_modules = ['q_proj', 'v_proj']
+        elif llm_lora_apply == "ffn":
+            target_modules = ['gate_proj', "up_proj", "down_proj"]
+        elif llm_lora_apply == "all":
+            target_modules = ['q_proj', 'v_proj',
+                              'gate_proj', "up_proj", "down_proj"]
+        else:
+            # NOTE(review): misconfiguration only prints a warning and then
+            # proceeds with an empty target_modules list — consider raising.
+            print("Wrong llm_lora_apply value in yaml!!")
+        print(f"applying llm lora on {llm_lora_apply}")
+        lora_config = LoraConfig(
+            r=llm_lora_r,
+            lora_alpha=2*llm_lora_r,
+            target_modules=target_modules,
+            # lora_dropout=training_args.lora_dropout,
+            # bias=training_args.lora_bias,
+            task_type="CAUSAL_LM",
+        )
+        self.llm_model = get_peft_model(self.llm_model, lora_config)
+        self.llm_model.print_trainable_parameters()
+
+        # TODO: add this in other lora setups; it's important to have float 32, otherwise loss may go to nan
+        # (the base LLM was loaded in fp16; trainable LoRA params are upcast)
+        for param in filter(lambda p: p.requires_grad, self.llm_model.parameters()):
+            param.data = param.data.to(torch.float32)
+
+        # Projects Q-Former outputs into the LLM embedding space.
+        self.llm_proj = nn.Linear(
+            self.Qformer.config.hidden_size, self.llm_model.config.hidden_size
+        )
+
+        self.max_txt_len = max_txt_len
+        self.max_output_txt_len = max_output_txt_len
+        self.prompt = prompt
+        prompt_tokens = self.llm_tokenizer(self.prompt, return_tensors="pt")
+        self.prompt_length = prompt_tokens.attention_mask.sum(1)
+
+        self._lemmatizer = None
+
+        self.qformer_text_input = qformer_text_input
+
+    def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts):
+        """Concatenate instruction and answer tokens per sample.
+
+        For each row, lays out [non-pad input tokens | output tokens minus
+        their BOS | input padding], so padding ends up after the real tokens
+        (right-padded tokenizer is assumed by the caller).
+
+        Returns:
+            (llm_tokens, input_part_targets_len): stacked 'input_ids' /
+            'attention_mask' tensors, and per-sample lengths of the input
+            (instruction) part, used later to mask the instruction from the
+            loss targets.
+        """
+        input_part_targets_len = []
+        llm_tokens = {"input_ids": [], "attention_mask": []}
+        for i in range(input_ids.size(0)):
+            # Number of real (non-pad) input tokens in this row.
+            this_input_ones = input_atts[i].sum()
+            input_part_targets_len.append(this_input_ones)
+            llm_tokens['input_ids'].append(
+                torch.cat([
+                    input_ids[i][:this_input_ones],
+                    output_ids[i][1:],
+                    input_ids[i][this_input_ones:]
+                ])
+            )
+            llm_tokens['attention_mask'].append(
+                torch.cat([
+                    input_atts[i][:this_input_ones],
+                    output_atts[i][1:],
+                    input_atts[i][this_input_ones:]
+                ])
+            )
+        llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
+        llm_tokens['attention_mask'] = torch.stack(
+            llm_tokens['attention_mask'])
+        return llm_tokens, input_part_targets_len
+
+    def concat_input_ans_pos_neg(self, input_ids, input_atts,
+                                 pos_ids, pos_atts,
+                                 neg_ids, neg_atts):
+        """Concatenate instruction, positive answer and negative answer tokens.
+
+        Per-row layout: [non-pad input | positive tokens minus BOS |
+        negative tokens with BOS kept | all leftover padding].
+
+        Returns:
+            (llm_tokens, input_part_targets_len, sign): stacked token dict,
+            per-sample instruction lengths, and a per-token 'sign' tensor with
+            -100 for instruction/padding, +1 for positive-answer tokens and
+            -1 for negative-answer tokens (used to split the CKD loss).
+        """
+        # NOTE(review): total_len is computed but never used below.
+        total_len = input_ids.shape[1]+pos_ids.shape[1]+neg_ids.shape[1]
+        input_part_targets_len = []
+        sign = []
+        llm_tokens = {"input_ids": [], "attention_mask": []}
+        for i in range(input_ids.size(0)):
+            this_input_ones = input_atts[i].sum()
+            input_part_targets_len.append(this_input_ones)
+
+            pos_len = pos_atts[i].sum()
+            neg_len = neg_atts[i].sum()
+            llm_tokens['input_ids'].append(
+                torch.cat([
+                    input_ids[i][:this_input_ones],
+                    pos_ids[i][1:pos_len],  # removed bos
+                    # keep bos; pos and neg parts are not related # TODO: verify w/ team
+                    neg_ids[i][:neg_len],
+                    # following are ignored parts
+                    input_ids[i][this_input_ones:],
+                    pos_ids[i][pos_len:],
+                    neg_ids[i][neg_len:],
+                ])
+            )
+            llm_tokens['attention_mask'].append(
+                torch.cat([
+                    input_atts[i][:this_input_ones],
+                    pos_atts[i][1:pos_len],  # removed bos
+                    # keep bos; pos and neg parts are not related # TODO: verify w/ team
+                    neg_atts[i][:neg_len],
+                    # following are ignored parts
+                    input_atts[i][this_input_ones:],
+                    pos_atts[i][pos_len:],
+                    neg_atts[i][neg_len:],
+                ])
+            )
+            sign.append(torch.cat([
+                torch.ones(this_input_ones)*-100,  # input ignored
+                torch.ones(pos_len-1)*1,  # positive # removed bos
+                # negative # keep bos # # TODO: verify w/ team
+                torch.ones(neg_len)*-1,
+                torch.ones((len(input_atts[i])-this_input_ones) +
+                           (len(pos_atts[i])-pos_len) +
+                           (len(neg_atts[i])-neg_len))*-100,  # pads ignored
+            ]))
+
+        llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
+        llm_tokens['attention_mask'] = torch.stack(
+            llm_tokens['attention_mask'])
+        sign = torch.stack(sign)
+        return llm_tokens, input_part_targets_len, sign
+
+    def forward(self, samples):
+        """Compute the training loss.
+
+        Expects in ``samples``: 'image', 'text_input', 'pos_descrition'
+        (positive target text — the misspelled key is a runtime contract with
+        the dataset side, do not "fix" it here alone), 'epoch', 'iters', and,
+        when kd_loss starts with 'ckd', 'neg_descrition'.
+
+        Returns a dict with 'loss' plus the individual loss terms.
+        """
+        # print('-----------------')
+        # print(samples["text_input"])
+        # print(samples["text_output"])
+        # print('-----------------')
+        # Log one example only on the very first iteration of the first epoch.
+        DEBUG = True if samples['epoch'] == 0 and samples['iters'] == 0 else False
+        use_negatives = False
+        if self.kd_loss.startswith('ckd'):
+            use_negatives = True
+
+        image = samples["image"]
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image))
+        image_atts = torch.ones(
+            image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        bs = image.size(0)
+        text_input = samples['text_input']
+        pos_descrition = samples['pos_descrition']
+        if use_negatives:
+            neg_descrition = samples['neg_descrition']
+        else:
+            # Placeholder so the DEBUG print below can always index [0].
+            neg_descrition = ['None']
+
+        if DEBUG:
+            print(f"EPOCH {samples['epoch']}",  'text_input:', text_input[0])
+            print(f"EPOCH {samples['epoch']}",
+                  'text_output: Positive:', pos_descrition[0]+' Negative: '+neg_descrition[0])
+
+        # tokenize
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        if self.qformer_text_input:
+            text_Qformer = self.tokenizer(
+                text_input,
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors="pt",
+            ).to(image.device)
+            query_atts = torch.ones(
+                query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+            Qformer_atts = torch.cat(
+                [query_atts, text_Qformer.attention_mask], dim=1)
+
+            query_output = self.Qformer.bert(
+                text_Qformer.input_ids,
+                attention_mask=Qformer_atts,
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+        else:
+            query_output = self.Qformer.bert(
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+
+        # Keep only the query-token positions and project them into the LLM
+        # embedding space; these act as soft visual prompts.
+        inputs_llm = self.llm_proj(
+            query_output.last_hidden_state[:, :query_tokens.size(1), :])
+        atts_llm = torch.ones(
+            inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+
+        self.llm_tokenizer.padding_side = "right"
+        self.llm_tokenizer.truncation_side = 'left'
+        text_input_tokens = self.llm_tokenizer(
+            text_input,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=self.max_txt_len,
+        ).to(image.device)
+
+        # Answers are truncated on the right (keep the beginning of the text).
+        self.llm_tokenizer.truncation_side = 'right'
+        text_pos_rat_tokens = self.llm_tokenizer(
+            # TODO: recheck eos_token; adding eos as pos part is not related to neg
+            [t + self.llm_tokenizer.eos_token for t in pos_descrition],
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=self.max_output_txt_len,
+        ).to(image.device)
+
+        if use_negatives:
+            text_neg_rat_tokens = self.llm_tokenizer(
+                # TODO: recheck eos_token
+                [t + self.llm_tokenizer.eos_token for t in neg_descrition],
+                return_tensors="pt",
+                padding="longest",
+                truncation=True,
+                max_length=self.max_output_txt_len,
+            ).to(image.device)
+
+        # merge tokens
+        if use_negatives:
+            llm_tokens, input_part_targets_len, sign = self.concat_input_ans_pos_neg(
+                text_input_tokens.input_ids,
+                text_input_tokens.attention_mask,
+                text_pos_rat_tokens.input_ids,
+                text_pos_rat_tokens.attention_mask,
+                text_neg_rat_tokens.input_ids,
+                text_neg_rat_tokens.attention_mask,
+            )
+        else:
+            llm_tokens, input_part_targets_len = self.concat_text_input_output(
+                text_input_tokens.input_ids,
+                text_input_tokens.attention_mask,
+                text_pos_rat_tokens.input_ids,
+                text_pos_rat_tokens.attention_mask,
+            )
+
+        # do not apply loss to the padding
+        targets = llm_tokens['input_ids'].masked_fill(
+            llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100
+        )
+
+        # do not apply loss to the text input (i.e., instruction)
+        for i, l in enumerate(input_part_targets_len):
+            targets[i][:l] = -100
+
+        # do not apply loss to the query tokens
+        empty_targets = (
+            torch.ones(atts_llm.size(), dtype=torch.long).to(
+                image.device).fill_(-100)
+        )
+        targets = torch.cat([empty_targets, targets], dim=1)
+
+        if use_negatives:
+            # Prepend -100 so 'sign' stays aligned with the query-token prefix.
+            sign = sign.type(torch.long).to(image.device)
+            sign = torch.cat([empty_targets, sign], dim=1)
+
+        inputs_embeds = self.llm_model.get_input_embeddings()(
+            llm_tokens['input_ids'])
+        inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+        attention_mask = torch.cat(
+            [atts_llm, llm_tokens['attention_mask']], dim=1)
+
+        # no contrastive kd
+        if not use_negatives:
+            with self.maybe_autocast():
+                outputs = self.llm_model(
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
+                    return_dict=True,
+                    return_lm_logits_seq_out=True,
+                    labels=targets,
+                )  # loss, logits, hidden_states
+
+            loss = outputs[0]
+            return {"loss": loss, 'ce_loss': loss}
+
+        # contrastive kd
+
+        # NOTE(review): only 'ckd' is handled below; kd_loss == 'ckd-pos'
+        # falls through this function and implicitly returns None — confirm
+        # whether that branch is handled elsewhere or is missing.
+        if self.kd_loss == 'ckd':
+
+            # masked_fill returns a new tensor, so 'labels' keeps the
+            # un-masked targets (pos AND neg positions supervised), while
+            # 'targets' has the negative positions masked out for the CE loss.
+            labels = targets
+            targets = targets.masked_fill(
+                sign == -1, -100
+            )
+
+            with self.maybe_autocast():
+                outputs = self.llm_model(
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
+                    return_dict=True,
+                    return_lm_logits_seq_out=True,
+                    labels=targets,
+                )  # loss, logits, hidden_states
+
+            pos_loss, logits, hidden_states = outputs
+
+            # contrastive loss
+
+            pos_logits = []
+            neg_targets = []
+            # sign = sign.view(bs, -1)
+            # NOTE(review): max_pos / max_neg are never used below.
+            max_pos = 0
+            max_neg = 0
+
+            # Standard causal shift: logits at position t predict token t+1.
+            shift_logits = logits[..., :-1, :].contiguous()
+            logit_sign = sign[..., :-1].contiguous()  # matching with logits
+
+            shift_labels = labels[..., 1:].contiguous()
+            label_sign = sign[..., 1:].contiguous()  # matching with labels
+            for k in range(bs):
+                # logit
+                _logit_sign = logit_sign[k].cpu()
+                pos_len = (_logit_sign == 1).sum().item()
+                pos_start = _logit_sign.numpy().tolist().index(1)  # first found location
+
+                # target
+                _label_sign = label_sign[k].cpu()
+                neg_len = (_label_sign == -1).sum().item()
+                neg_start = _label_sign.numpy().tolist().index(-1)  # first found location
+
+                # logit-target
+                _len = min(pos_len, neg_len)  # TODO: check this
+                # we are avoiding zero-padding and trimming
+                _pos_logits = shift_logits[k, pos_start:pos_start+_len]
+                _neg_targets = shift_labels[k, neg_start:neg_start+_len]
+                # pad to make size same
+
+                pos_logits.append(_pos_logits)
+                neg_targets.append(_neg_targets)
+
+            pos_logits = torch.cat(pos_logits)
+            neg_targets = torch.cat(neg_targets)
+
+            # Penalize probability mass assigned to the negative tokens at the
+            # positive positions: NLL of log(1 - p). clamp avoids log(0).
+            # NOTE(review): F.softmax is called without dim= (deprecated
+            # implicit-dim behavior) — should be dim=-1; confirm and fix.
+            neg_loss = F.nll_loss(torch.log(torch.clamp(
+                (1.0 - F.softmax(pos_logits)), min=1e-5)), neg_targets, reduction='mean')
+            loss = pos_loss * self.alpha + neg_loss * (1-self.alpha)
+
+            return {"loss": loss, 'pos_loss': pos_loss, 'neg_loss': neg_loss}
+
+    @torch.no_grad()
+    def generate(
+        self,
+        samples,
+        use_nucleus_sampling=False,
+        num_beams=5,
+        max_length=256,
+        min_length=1,
+        top_p=0.9,
+        repetition_penalty=1.5,
+        length_penalty=1,
+        num_captions=1,
+        temperature=1,
+    ):
+        """Generate text for a batch of images (and optional prompts).
+
+        Args:
+            samples: dict with 'image' (4-D for images, 5-D for video frame
+                stacks) and optionally 'prompt' and 'ocr_tokens'.
+            use_nucleus_sampling: sample with top-p instead of beam search.
+            num_beams / max_length / min_length / top_p / repetition_penalty /
+                length_penalty / num_captions / temperature: forwarded to
+                ``llm_model.generate``.
+
+        Returns:
+            list[str]: one stripped decoded string per generated sequence.
+        """
+        self.llm_tokenizer.padding_side = "left"
+
+        if "prompt" in samples.keys():
+            prompt = samples["prompt"]
+        else:
+            prompt = self.prompt
+
+        image = samples["image"]
+
+        bs = image.size(0)
+
+        # Broadcast a single string prompt across the batch.
+        if isinstance(prompt, str):
+            prompt = [prompt] * bs
+        else:
+            assert len(
+                prompt) == bs, "The number of prompts must be equal to the batch size."
+
+        # For TextCaps
+        if "ocr_tokens" in samples.keys() and "{}" in prompt[0]:
+            prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30]))
+                      for i, p in enumerate(prompt)]
+
+        query_tokens = self.query_tokens.expand(bs, -1, -1)
+        if self.qformer_text_input:
+            # remove ocr tokens in q_former (for eval textvqa)
+            # qformer_prompt = prompt
+            # qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt]
+
+            text_Qformer = self.tokenizer(
+                prompt,
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors="pt",
+            ).to(image.device)
+            query_atts = torch.ones(
+                query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+            Qformer_atts = torch.cat(
+                [query_atts, text_Qformer.attention_mask], dim=1)
+
+        # For video data
+        if image.dim() == 5:
+            # Encode each frame independently, then concatenate the projected
+            # query outputs along the sequence dimension.
+            inputs_llm, atts_llm = [], []
+            for j in range(image.size(2)):
+                this_frame = image[:, :, j, :, :]
+                with self.maybe_autocast():
+                    frame_embeds = self.ln_vision(
+                        self.visual_encoder(this_frame))
+                frame_atts = torch.ones(
+                    frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+                if self.qformer_text_input:
+                    frame_query_output = self.Qformer.bert(
+                        text_Qformer.input_ids,
+                        attention_mask=Qformer_atts,
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
+                else:
+                    frame_query_output = self.Qformer.bert(
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
+                frame_inputs_llm = self.llm_proj(
+                    frame_query_output.last_hidden_state[:, :query_tokens.size(1), :])
+                frame_atts_llm = torch.ones(frame_inputs_llm.size()[
+                                            :-1], dtype=torch.long).to(image.device)
+                inputs_llm.append(frame_inputs_llm)
+                atts_llm.append(frame_atts_llm)
+            inputs_llm = torch.cat(inputs_llm, dim=1)
+            atts_llm = torch.cat(atts_llm, dim=1)
+        else:
+            with self.maybe_autocast():
+                image_embeds = self.ln_vision(self.visual_encoder(image))
+            image_atts = torch.ones(
+                image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+            if self.qformer_text_input:
+                query_output = self.Qformer.bert(
+                    text_Qformer.input_ids,
+                    attention_mask=Qformer_atts,
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+            else:
+                query_output = self.Qformer.bert(
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+
+            inputs_llm = self.llm_proj(
+                query_output.last_hidden_state[:, :query_tokens.size(1), :])
+            atts_llm = torch.ones(
+                inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+
+        llm_tokens = self.llm_tokenizer(
+            prompt,
+            padding="longest",
+            return_tensors="pt"
+        ).to(image.device)
+
+        # print('EVAL INPUT: ', prompt)
+        with self.maybe_autocast():
+            # Prepend the projected visual tokens to the prompt embeddings.
+            inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids)
+            inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+            attention_mask = torch.cat(
+                [atts_llm, llm_tokens.attention_mask], dim=1)
+
+            outputs = self.llm_model.generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                do_sample=use_nucleus_sampling,
+                top_p=top_p,
+                temperature=temperature,
+                num_beams=num_beams,
+                max_length=max_length,
+                min_length=min_length,
+                # eos_token_id=self.eos_token_id,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
+                num_return_sequences=num_captions,
+            )
+
+        outputs[outputs == 0] = 2  # convert output id 0 to 2 (eos_token_id)
+        output_text = self.llm_tokenizer.batch_decode(
+            outputs, skip_special_tokens=True)
+        output_text = [text.strip() for text in output_text]
+
+        # print('EVAL OUTPUT: ', output_text)
+
+        return output_text
+
+    def predict_answers(
+        self,
+        samples,
+        num_beams=5,
+        inference_method="generate",
+        max_len=10,
+        min_len=1,
+        num_ans_candidates=128,
+        answer_list=None,
+        prompt="",
+        length_penalty=0,
+        **kwargs
+    ):
+        """Answer questions by formatting prompts and calling ``generate``.
+
+        Supports three prompt shapes: two '{}' slots filled with OCR tokens
+        or multiple-choice options plus the question, one '{}' slot filled
+        with the question, or no template at all (raw 'text_input').
+        ``inference_method``, ``num_ans_candidates`` and ``answer_list`` are
+        accepted for interface compatibility but unused here.
+
+        Returns:
+            list[str]: one answer per sample, optionally lemmatized when
+            samples['apply_lemmatizer'] is truthy.
+        """
+        if isinstance(samples["text_input"], str):
+            samples["text_input"] = [samples["text_input"]]
+
+        if prompt:
+            if prompt.count("{}") == 2:
+                if 'ocr_tokens' in samples:
+                    text_input = [
+                        prompt.format(
+                            ', '.join(samples['ocr_tokens'][i][:30]), samples["text_input"][i])
+                        for i in range(len(samples["text_input"]))]
+                elif 'choices' in samples:
+                    # Render choices as "(a) ... (b) ..." for A-OKVQA-style MC.
+                    text_input = []
+                    for i in range(len(samples["text_input"])):
+                        this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(
+                            samples["choices"][i])]
+                        this_choices = " ".join(this_choices)
+                        text_input.append(prompt.format(
+                            samples["text_input"][i], this_choices))
+            else:
+                text_input = [prompt.format(question)
+                              for question in samples["text_input"]]
+        else:
+            text_input = samples["text_input"]
+
+        samples["prompt"] = text_input
+
+        output_text = self.generate(
+            samples,
+            num_beams=num_beams,
+            max_length=max_len,
+            min_length=min_len,
+            length_penalty=length_penalty
+        )
+
+        if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
+            output_text = self._lemmatize(output_text)
+
+        return output_text
+
+    def predict_class(
+        self,
+        samples,
+        candidates,
+        n_segments=1,
+    ):
+        """Rank candidate answers for each sample via ``_predict_class``.
+
+        If ``candidates`` is a list of per-sample candidate lists, each sample
+        is scored one at a time against its own candidates; otherwise the
+        shared candidate list is scored for the whole batch in one call.
+        ``n_segments`` splits the candidates to bound memory use.
+        """
+        self.llm_tokenizer.padding_side = "left"
+
+        # If candidates is a list of lists, each sample has its candidates, then we need to iterate one by one
+        if type(candidates[0]) == list:
+            results = []
+
+            for i in range(samples["image"].size(0)):
+                # Build a single-sample view, carrying over optional fields.
+                this_sample = {
+                    "image": samples["image"][i].unsqueeze(0),
+                    "prompt": samples["prompt"],
+                }
+
+                if "text_input" in samples.keys():
+                    this_sample["text_input"] = [samples["text_input"][i]]
+
+                if 'context' in samples.keys():
+                    this_sample['context'] = [samples["context"][i]]
+
+                if 'history' in samples.keys():
+                    this_sample['history'] = [samples["history"][i]]
+
+                if 'caption' in samples.keys():
+                    this_sample['caption'] = [samples["caption"][i]]
+
+                this_result = self._predict_class(
+                    this_sample, candidates[i], n_segments)
+                results.append(this_result)
+
+            try:
+                results = torch.cat(results, dim=0)
+            # NOTE(review): bare except — intended as a fallback for ragged
+            # per-sample result sizes; consider narrowing to RuntimeError.
+            except:
+                results = [res.tolist()[0] for res in results]
+
+            return results
+
+        return self._predict_class(samples, candidates, n_segments)
+
+    def _predict_class(
+        self,
+        samples,
+        candidates,
+        n_segments=1,
+    ):
+        image = samples["image"]
+        prompt = samples["prompt"]
+
+        bs = image.size(0)
+
+        if isinstance(prompt, str):
+            prompt = [prompt] * bs
+        else:
+            assert len(
+                prompt) == bs, "The number of prompts must be equal to the batch size."
+
+        if "text_input" in samples.keys():
+            if type(samples["text_input"][0]) == list:
+                prompt = [prompt[i].format(*samples["text_input"][i])
+                          for i in range(len(prompt))]
+            else:
+                prompt = [prompt[i].format(samples["text_input"][i])
+                          for i in range(len(prompt))]
+
+        # scienceqa
+        if 'context' in samples.keys() and samples['context'] != '':
+            prompt = [
+                f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))]
+
+        # visual dialog
+        if 'history' in samples.keys() and samples['history'][0] != '':
+            prompt = [f'dialog history: {samples["history"][i]}\n{prompt[i]}' for i in range(
+                len(prompt))]
+
+        if 'caption' in samples.keys() and samples['caption'][0] != '':
+            prompt = [f'This image has the caption "{samples["caption"][i]}". {prompt[i]}' for i in range(
+                len(prompt))]
+
+        query_tokens = self.query_tokens.expand(bs, -1, -1)
+        if self.qformer_text_input:
+            text_Qformer = self.tokenizer(
+                prompt,
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors="pt"
+            ).to(image.device)
+            query_atts = torch.ones(
+                query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+            Qformer_atts = torch.cat(
+                [query_atts, text_Qformer.attention_mask], dim=1)
+
+        if image.dim() == 5:
+            inputs_llm, atts_llm = [], []
+            for j in range(image.size(2)):
+                this_frame = image[:, :, j, :, :]
+                with self.maybe_autocast():
+                    frame_embeds = self.ln_vision(
+                        self.visual_encoder(this_frame))
+                    frame_atts = torch.ones(
+                        frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+                if self.qformer_text_input:
+                    frame_query_output = self.Qformer.bert(
+                        text_Qformer.input_ids,
+                        attention_mask=Qformer_atts,
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
+                else:
+                    frame_query_output = self.Qformer.bert(
+                        query_embeds=query_tokens,
+                        encoder_hidden_states=frame_embeds,
+                        encoder_attention_mask=frame_atts,
+                        return_dict=True,
+                    )
+
+                frame_inputs_llm = self.llm_proj(
+                    frame_query_output.last_hidden_state[:, :query_tokens.size(1), :])
+                frame_atts_llm = torch.ones(frame_inputs_llm.size()[
+                                            :-1], dtype=torch.long).to(image.device)
+                inputs_llm.append(frame_inputs_llm)
+                atts_llm.append(frame_atts_llm)
+            inputs_llm = torch.cat(inputs_llm, dim=1)
+            atts_llm = torch.cat(atts_llm, dim=1)
+        else:
+            with self.maybe_autocast():
+                image_embeds = self.ln_vision(self.visual_encoder(image))
+            image_atts = torch.ones(
+                image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+            if self.qformer_text_input:
+                query_output = self.Qformer.bert(
+                    text_Qformer.input_ids,
+                    attention_mask=Qformer_atts,
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+            else:
+                query_output = self.Qformer.bert(
+                    query_embeds=query_tokens,
+                    encoder_hidden_states=image_embeds,
+                    encoder_attention_mask=image_atts,
+                    return_dict=True,
+                )
+
+            inputs_llm = self.llm_proj(
+                query_output.last_hidden_state[:, :query_tokens.size(1), :])
+            atts_llm = torch.ones(
+                inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+
+        self.llm_tokenizer.padding_side = "right"
+        self.llm_tokenizer.truncation_side = 'left'
+        text_input_tokens = self.llm_tokenizer(
+            prompt,
+            return_tensors="pt",
+            padding="longest",
+            # truncation=True,
+            # max_length=self.max_txt_len,
+        ).to(image.device)
+
+        empty_targets = torch.ones(atts_llm.size(), dtype=torch.long).to(
+            image.device).fill_(-100)
+
+        # self.llm_tokenizer.padding_side = "right"
+        self.llm_tokenizer.truncation_side = 'right'
+        n_cands = len(candidates)
+        with self.maybe_autocast(dtype=torch.bfloat16):
+            all_losses = []
+            for n in range(n_segments):
+                seg_len = n_cands // n_segments
+                if n == (n_segments - 1):
+                    seg_len = n_cands - seg_len * (n_segments - 1)
+
+                start_i = n * (n_cands // n_segments)
+                end_i = start_i + seg_len
+
+                this_output_tokens = self.llm_tokenizer(
+                    candidates[start_i:end_i],
+                    return_tensors="pt",
+                    padding="longest",
+                    # truncation=True,
+                    # max_length=self.max_output_txt_len,
+                ).to(image.device)
+
+                this_input_tokens_ids = text_input_tokens.input_ids.repeat_interleave(
+                    seg_len, dim=0)
+                this_input_tokens_atts = text_input_tokens.attention_mask.repeat_interleave(
+                    seg_len, dim=0)
+
+                this_output_tokens_ids = this_output_tokens.input_ids.repeat(
+                    bs, 1)
+                this_output_tokens_atts = this_output_tokens.attention_mask.repeat(
+                    bs, 1)
+
+                this_llm_tokens, this_input_targets_len = self.concat_text_input_output(
+                    this_input_tokens_ids,
+                    this_input_tokens_atts,
+                    this_output_tokens_ids,
+                    this_output_tokens_atts
+                )
+
+                this_llm_input_ids = this_llm_tokens['input_ids']
+                this_llm_atts = this_llm_tokens['attention_mask']
+                # this_llm_input_ids = torch.cat([this_input_tokens_ids, this_output_tokens_ids], dim=1)
+                # this_llm_atts = torch.cat([this_input_tokens_atts, this_output_tokens_atts], dim=1)
+
+                inputs_embeds = self.llm_model.get_input_embeddings()(this_llm_input_ids)
+                inputs_embeds = torch.cat(
+                    [inputs_llm.repeat_interleave(seg_len, dim=0), inputs_embeds], dim=1)
+                attention_mask = torch.cat(
+                    [atts_llm.repeat_interleave(seg_len, dim=0), this_llm_atts], dim=1)
+
+                this_targets = this_llm_input_ids.masked_fill(
+                    this_llm_input_ids == self.llm_tokenizer.pad_token_id, -100)
+                # this_targets[:, :this_input_tokens_ids.size(1)] = -100
+                for i, l in enumerate(this_input_targets_len):
+                    this_targets[i][:l] = -100
+
+                this_targets = torch.cat(
+                    [empty_targets.repeat_interleave(seg_len, dim=0), this_targets], dim=1)
+
+                outputs = self.llm_model(
+                    inputs_embeds=inputs_embeds,
+                    attention_mask=attention_mask,
+                    return_dict=True,
+                    labels=this_targets,
+                    reduction="none",
+                )
+
+                loss = outputs.loss
+
+                loss = loss.reshape(bs, seg_len)
+                # output_class_ranks = torch.argsort(loss, dim=-1)
+                all_losses.append(loss)
+
+            all_losses = torch.cat(all_losses, dim=-1)
+            output_class_ranks = torch.argsort(all_losses, dim=-1)
+
+        return output_class_ranks
+
+    def _lemmatize(self, answers):
+        def apply(answer):
+            doc = self.lemmatizer(answer)
+
+            words = []
+            for token in doc:
+                if token.pos_ in ["NOUN", "VERB"]:
+                    words.append(token.lemma_)
+                else:
+                    words.append(token.text)
+            answer = " ".join(words)
+
+            return answer
+
+        return [apply(answer) for answer in answers]
+
+    @property
+    def lemmatizer(self):
+        if self._lemmatizer is None:
+            try:
+                import spacy
+
+                self._lemmatizer = spacy.load("en_core_web_sm")
+            except ImportError:
+                logging.error(
+                    """
+                    Please install spacy and en_core_web_sm model to apply lemmatization.
+                    python -m spacy download en_core_web_sm
+                    OR
+                    import spacy.cli
+                    spacy.cli.download("en_core_web_sm")
+                    """
+                )
+                exit(1)
+
+        return self._lemmatizer
+
+    @classmethod
+    def from_config(cls, cfg):
+        vit_model = cfg.get("vit_model", "eva_clip_g")
+        img_size = cfg.get("image_size")
+        num_query_token = cfg.get("num_query_token")
+        llm_model = cfg.get("llm_model")
+
+        drop_path_rate = cfg.get("drop_path_rate", 0)
+        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
+        vit_precision = cfg.get("vit_precision", "fp16")
+        freeze_vit = cfg.get("freeze_vit", True)
+
+        prompt = cfg.get("prompt", "")
+        max_txt_len = cfg.get("max_txt_len", 128)
+        max_output_txt_len = cfg.get("max_output_txt_len", 256)
+
+        apply_lemmatizer = cfg.get("apply_lemmatizer", False)
+
+        qformer_text_input = cfg.get("qformer_text_input", True)
+
+        alpha = cfg.get("alpha", 1)  # no ckd
+        kd_loss = cfg.get("kd_loss", 'kd')  # no ckd
+        llm_lora_r = cfg.get("llm_lora_r", 8)
+        llm_lora_apply = cfg.get("llm_lora_apply", "attn")
+
+        model = cls(
+            vit_model=vit_model,
+            img_size=img_size,
+            drop_path_rate=drop_path_rate,
+            use_grad_checkpoint=use_grad_checkpoint,
+            vit_precision=vit_precision,
+            freeze_vit=freeze_vit,
+            num_query_token=num_query_token,
+            llm_model=llm_model,
+            prompt=prompt,
+            max_txt_len=max_txt_len,
+            max_output_txt_len=max_output_txt_len,
+            apply_lemmatizer=apply_lemmatizer,
+            qformer_text_input=qformer_text_input,
+            alpha=alpha,
+            kd_loss=kd_loss,
+            llm_lora_r=llm_lora_r,
+            llm_lora_apply=llm_lora_apply
+        )
+
+        # if qformer_text_input:
+        #     # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal)
+        #     model.load_from_pretrained(
+        #         url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth"
+        #     )
+
+        model.load_checkpoint_from_config(cfg)
+
+        return model
+
+    def load_from_pretrained(self, url_or_filename):
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(
+                url_or_filename, check_hash=False, progress=True
+            )
+            checkpoint = torch.load(cached_file, map_location="cpu")
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location="cpu")
+        else:
+            raise RuntimeError("checkpoint url or path is invalid")
+
+        if "model" in checkpoint:
+            state_dict = checkpoint["model"]
+        else:
+            state_dict = checkpoint
+
+        # strict=False for peft layers
+        msg = self.load_state_dict(state_dict, strict=False)
+
+        # logging.info("Missing keys {}".format(msg.missing_keys))
+        logging.info("load checkpoint from %s" % url_or_filename)
+
+        return msg
diff --git a/lavis/models/blip2_models/modeling_llama.py b/lavis/models/blip2_models/modeling_llama.py
index 8f1b2fa..d7ac195 100644
--- a/lavis/models/blip2_models/modeling_llama.py
+++ b/lavis/models/blip2_models/modeling_llama.py
@@ -687,6 +687,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         reduction: Optional[str] = "mean",
+        return_lm_logits_seq_out: Optional[bool] = False  # added by PS
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -756,6 +757,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
+        if return_lm_logits_seq_out:  # added by PS
+            if loss is not None:
+                return loss, logits, hidden_states
+            else:
+                return logits, hidden_states
+
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
diff --git a/lavis/models/blip2_models/modeling_t5.py b/lavis/models/blip2_models/modeling_t5.py
index 069cbb6..4adbba8 100644
--- a/lavis/models/blip2_models/modeling_t5.py
+++ b/lavis/models/blip2_models/modeling_t5.py
@@ -1775,6 +1775,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         reduction: Optional[str] = "mean",
+        return_lm_logits_seq_out: Optional[bool] = False  # added by PS
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1908,6 +1909,11 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         if not return_dict:
             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
             return ((loss,) + output) if loss is not None else output
+        if return_lm_logits_seq_out:  # added by PS
+            if loss is not None:
+                return loss, lm_logits, sequence_output
+            else:
+                return lm_logits, sequence_output
 
         return Seq2SeqLMOutput(
             loss=loss,
diff --git a/lavis/models/clip_vit.py b/lavis/models/clip_vit.py
index 372c10c..120812b 100644
--- a/lavis/models/clip_vit.py
+++ b/lavis/models/clip_vit.py
@@ -276,7 +276,8 @@ def create_clip_vit_L(img_size=224, use_checkpoint=False, precision="fp16"):
     interpolate_pos_embed(model, state_dict)
 
     incompatible_keys = model.load_state_dict(state_dict, strict=False)
-    # print(incompatible_keys)
+    print("loaded vision encoder with pretrained weights")
+    print(incompatible_keys)
 
     if precision == "fp16":
         convert_weights_to_fp16(model)
diff --git a/lavis/models/eva_vit.py b/lavis/models/eva_vit.py
index 213d765..08b7d62 100644
--- a/lavis/models/eva_vit.py
+++ b/lavis/models/eva_vit.py
@@ -489,7 +489,8 @@ def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, pre
     interpolate_pos_embed(model, state_dict)
 
     incompatible_keys = model.load_state_dict(state_dict, strict=False)
-#     print(incompatible_keys)
+    print("loaded vision encoder with pretrained weights")
+    print(incompatible_keys)
 
     if precision == "fp16":
         #         model.to("cuda")
diff --git a/lavis/processors/__init__.py b/lavis/processors/__init__.py
index 84d897c..364820a 100644
--- a/lavis/processors/__init__.py
+++ b/lavis/processors/__init__.py
@@ -16,6 +16,7 @@ from lavis.processors.blip_processors import (
     Blip2ImageTrainProcessor,
     BlipImageEvalProcessor,
     BlipCaptionProcessor,
+    BlipCaptionProcessorCKD,
 )
 from lavis.processors.blip_diffusion_processors import (
     BlipDiffusionInputImageProcessor,
@@ -30,6 +31,7 @@ from lavis.processors.audio_processors import BeatsAudioProcessor
 from lavis.processors.ulip_processors import ULIPPCProcessor
 from lavis.processors.instruction_text_processors import BlipInstructionProcessor
 
+
 from lavis.common.registry import registry
 
 __all__ = [
@@ -42,6 +44,7 @@ __all__ = [
     "Blip2ImageTrainProcessor",
     "BlipImageEvalProcessor",
     "BlipCaptionProcessor",
+    "BlipCaptionProcessorCKD",
     "BlipInstructionProcessor",
     # BLIP-Diffusion
     "BlipDiffusionInputImageProcessor",
diff --git a/lavis/processors/blip_processors.py b/lavis/processors/blip_processors.py
index 58fa21b..387f44a 100644
--- a/lavis/processors/blip_processors.py
+++ b/lavis/processors/blip_processors.py
@@ -68,6 +68,49 @@ class BlipCaptionProcessor(BaseProcessor):
         return caption
 
 
+@registry.register_processor("blip_caption_ckd")
+class BlipCaptionProcessorCKD(BaseProcessor):
+    def __init__(self, prompt="", max_words=50):
+        self.prompt = prompt
+        self.max_words = max_words
+
+    def __call__(self, caption):
+        caption = self.prompt + self.pre_caption(caption)
+
+        return caption
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        prompt = cfg.get("prompt", "")
+        max_words = cfg.get("max_words", 50)
+
+        return cls(prompt=prompt, max_words=max_words)
+
+    def pre_caption(self, caption):
+        caption = re.sub(
+            r"([!\"()*#:;~])",
+            " ",
+            caption,
+        )
+        caption = re.sub(
+            r"\s{2,}",
+            " ",
+            caption,
+        )
+        caption = caption.rstrip("\n")
+        caption = caption.strip(" ")
+
+        # truncate caption
+        caption_words = caption.split(" ")
+        if len(caption_words) > self.max_words:
+            caption = " ".join(caption_words[: self.max_words])
+
+        return caption
+
+
 @registry.register_processor("blip_question")
 class BlipQuestionProcessor(BaseProcessor):
     def __init__(self, max_words=50):
diff --git a/lavis/tasks/vqa.py b/lavis/tasks/vqa.py
index 9d0a26d..4c8b047 100644
--- a/lavis/tasks/vqa.py
+++ b/lavis/tasks/vqa.py
@@ -114,7 +114,7 @@ class VQATask(BaseTask):
                         if dist_utils.get_rank() == 0:
                             os.makedirs(os.path.join(registry.get_path(
                                 "cache_root"), f'{ds_name}_gt'), exist_ok=True)
-                            try:
+                            try:  # FIXME: check why the downloaded annotation differs from theirs
                                 convert_to_coco_gt(
                                     dataset, self.ques_files[split], self.anno_files[split], split, self.sample_id_key)
                             except:
